//============================================================================ // Copyright (c) Kitware, Inc. // All rights reserved. // See LICENSE.txt for details. // // This software is distributed WITHOUT ANY WARRANTY; without even // the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR // PURPOSE. See the above copyright notice for more information. //============================================================================ #ifndef vtk_m_worklet_KdTree3DConstruction_h #define vtk_m_worklet_KdTree3DConstruction_h #include #include #include #include #include #include #include #include #include #include #include #include #include namespace vtkm { namespace worklet { namespace spatialstructure { class KdTree3DConstruction { public: ////////// General WORKLET for Kd-tree ////// class ComputeFlag : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn rank, FieldIn pointCountInSeg, FieldOut flag); using ExecutionSignature = void(_1, _2, _3); VTKM_CONT ComputeFlag() {} template VTKM_EXEC void operator()(const T& rank, const T& pointCountInSeg, T& flag) const { if (static_cast(rank) >= static_cast(pointCountInSeg) / 2.0f) flag = 1; //right subtree else flag = 0; //left subtree } }; class InverseArray : public vtkm::worklet::WorkletMapField { //only for 0/1 array public: using ControlSignature = void(FieldIn in, FieldOut out); using ExecutionSignature = void(_1, _2); VTKM_CONT InverseArray() {} template VTKM_EXEC void operator()(const T& in, T& out) const { if (in == 0) out = 1; else out = 0; } }; class SegmentedSplitTransform : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn B, FieldIn D, FieldIn F, FieldIn G, FieldIn H, FieldOut I); using ExecutionSignature = void(_1, _2, _3, _4, _5, _6); VTKM_CONT SegmentedSplitTransform() {} template VTKM_EXEC void operator()(const T& B, const T& D, const T& F, const T& G, const T& H, T& I) const { if (B == 1) { I = F + H + D; } else { I = F + G - 1; } } }; class ScatterArray : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn in, FieldIn index, WholeArrayOut out); using ExecutionSignature = void(_1, _2, _3); VTKM_CONT ScatterArray() {} template VTKM_EXEC void operator()(const T& in, const T& index, const OutputArrayPortalType& outputPortal) const { outputPortal.Set(index, in); } }; class NewSegmentId : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn inSegmentId, FieldIn flag, FieldOut outSegmentId); using ExecutionSignature = void(_1, _2, _3); VTKM_CONT NewSegmentId() {} template VTKM_EXEC void operator()(const T& oldSegId, const T& flag, T& newSegId) const { if (flag == 0) newSegId = oldSegId * 2; else newSegId = oldSegId * 2 + 1; } }; class SaveSplitPointId : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn pointId, FieldIn flag, FieldIn oldSplitPointId, FieldOut newSplitPointId); using ExecutionSignature = void(_1, _2, _3, _4); VTKM_CONT SaveSplitPointId() {} template VTKM_EXEC void operator()(const T& pointId, const T& flag, const T& oldSplitPointId, T& newSplitPointId) const { if (flag == 0) //do not change newSplitPointId = oldSplitPointId; else //split point id newSplitPointId = pointId; } }; class FindSplitPointId : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn pointId, FieldIn rank, FieldOut splitIdInsegment); using ExecutionSignature = void(_1, _2, _3); VTKM_CONT FindSplitPointId() {} template VTKM_EXEC void operator()(const T& pointId, const T& rank, T& splitIdInsegment) const { if (rank == 0) //do not change splitIdInsegment = pointId; else //split point id splitIdInsegment = -1; //indicate this is not split point } }; class ArrayAdd : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn inArray0, FieldIn inArray1, FieldOut outArray); using ExecutionSignature = void(_1, _2, _3); VTKM_CONT ArrayAdd() {} template VTKM_EXEC void operator()(const T& in0, const T& in1, T& out) const { out = in0 + in1; } }; class SeprateVec3AryHandle : public vtkm::worklet::WorkletMapField { public: using ControlSignature = void(FieldIn inVec3, FieldOut out0, FieldOut out1, FieldOut out2); using ExecutionSignature = void(_1, _2, _3, _4); VTKM_CONT SeprateVec3AryHandle() {} template VTKM_EXEC void operator()(const Vec& inVec3, T& out0, T& out1, T& out2) const { out0 = inVec3[0]; out1 = inVec3[1]; out2 = inVec3[2]; } }; ////////// General worklet WRAPPER for Kd-tree ////// template vtkm::cont::ArrayHandle ReverseScanInclusiveByKey(vtkm::cont::ArrayHandle& keyHandle, vtkm::cont::ArrayHandle& dataHandle, BinaryFunctor binary_functor) { using Algorithm = vtkm::cont::Algorithm; vtkm::cont::ArrayHandle resultHandle; auto reversedResultHandle = vtkm::cont::make_ArrayHandleReverse(resultHandle); Algorithm::ScanInclusiveByKey(vtkm::cont::make_ArrayHandleReverse(keyHandle), vtkm::cont::make_ArrayHandleReverse(dataHandle), reversedResultHandle, binary_functor); return resultHandle; } template vtkm::cont::ArrayHandle Inverse01ArrayWrapper(vtkm::cont::ArrayHandle& inputHandle) { vtkm::cont::ArrayHandle InverseHandle; InverseArray invWorklet; vtkm::worklet::DispatcherMapField inverseArrayDispatcher(invWorklet); inverseArrayDispatcher.Invoke(inputHandle, InverseHandle); return InverseHandle; } template vtkm::cont::ArrayHandle ScatterArrayWrapper(vtkm::cont::ArrayHandle& inputHandle, vtkm::cont::ArrayHandle& indexHandle) { vtkm::cont::ArrayHandle outputHandle; outputHandle.Allocate(inputHandle.GetNumberOfValues()); ScatterArray scatterWorklet; vtkm::worklet::DispatcherMapField scatterArrayDispatcher(scatterWorklet); scatterArrayDispatcher.Invoke(inputHandle, indexHandle, outputHandle); return outputHandle; } template vtkm::cont::ArrayHandle NewKeyWrapper(vtkm::cont::ArrayHandle& oldSegIdHandle, vtkm::cont::ArrayHandle& flagHandle) { vtkm::cont::ArrayHandle newSegIdHandle; NewSegmentId newsegidWorklet; vtkm::worklet::DispatcherMapField newSegIdDispatcher(newsegidWorklet); newSegIdDispatcher.Invoke(oldSegIdHandle, flagHandle, newSegIdHandle); return newSegIdHandle; } template vtkm::cont::ArrayHandle SaveSplitPointIdWrapper(vtkm::cont::ArrayHandle& pointIdHandle, vtkm::cont::ArrayHandle& flagHandle, vtkm::cont::ArrayHandle& rankHandle, vtkm::cont::ArrayHandle& oldSplitIdHandle) { vtkm::cont::ArrayHandle splitIdInSegmentHandle; FindSplitPointId findSplitPointIdWorklet; vtkm::worklet::DispatcherMapField findSplitPointIdWorkletDispatcher( findSplitPointIdWorklet); findSplitPointIdWorkletDispatcher.Invoke(pointIdHandle, rankHandle, splitIdInSegmentHandle); vtkm::cont::ArrayHandle splitIdInSegmentByScanHandle = ReverseScanInclusiveByKey(flagHandle, splitIdInSegmentHandle, vtkm::Maximum()); vtkm::cont::ArrayHandle splitIdHandle; SaveSplitPointId saveSplitPointIdWorklet; vtkm::worklet::DispatcherMapField saveSplitPointIdWorkletDispatcher( saveSplitPointIdWorklet); saveSplitPointIdWorkletDispatcher.Invoke( splitIdInSegmentByScanHandle, flagHandle, oldSplitIdHandle, splitIdHandle); return splitIdHandle; } template vtkm::cont::ArrayHandle ArrayAddWrapper(vtkm::cont::ArrayHandle& array0Handle, vtkm::cont::ArrayHandle& array1Handle) { vtkm::cont::ArrayHandle resultHandle; ArrayAdd arrayAddWorklet; vtkm::worklet::DispatcherMapField arrayAddDispatcher(arrayAddWorklet); arrayAddDispatcher.Invoke(array0Handle, array1Handle, resultHandle); return resultHandle; } /////////////////////////////////////////////////// ////////General Kd tree function ////////////////// /////////////////////////////////////////////////// template vtkm::cont::ArrayHandle ComputeFlagProcedure(vtkm::cont::ArrayHandle& rankHandle, vtkm::cont::ArrayHandle& segIdHandle) { using Algorithm = vtkm::cont::Algorithm; vtkm::cont::ArrayHandle segCountAryHandle; { vtkm::cont::ArrayHandle tmpAryHandle; vtkm::cont::ArrayHandleConstant constHandle(1, rankHandle.GetNumberOfValues()); Algorithm::ScanInclusiveByKey( segIdHandle, constHandle, tmpAryHandle, vtkm::Add()); //compute ttl segs in segment segCountAryHandle = ReverseScanInclusiveByKey(segIdHandle, tmpAryHandle, vtkm::Maximum()); } vtkm::cont::ArrayHandle flagHandle; vtkm::worklet::DispatcherMapField computeFlagDispatcher; computeFlagDispatcher.Invoke(rankHandle, segCountAryHandle, flagHandle); return flagHandle; } template vtkm::cont::ArrayHandle SegmentedSplitProcedure(vtkm::cont::ArrayHandle& A_Handle, vtkm::cont::ArrayHandle& B_Handle, vtkm::cont::ArrayHandle& C_Handle) { using Algorithm = vtkm::cont::Algorithm; vtkm::cont::ArrayHandle D_Handle; T initValue = 0; Algorithm::ScanExclusiveByKey(C_Handle, B_Handle, D_Handle, initValue, vtkm::Add()); vtkm::cont::ArrayHandleCounting Ecouting_Handle(0, 1, A_Handle.GetNumberOfValues()); vtkm::cont::ArrayHandle E_Handle; Algorithm::Copy(Ecouting_Handle, E_Handle); vtkm::cont::ArrayHandle F_Handle; Algorithm::ScanInclusiveByKey(C_Handle, E_Handle, F_Handle, vtkm::Minimum()); vtkm::cont::ArrayHandle InvB_Handle = Inverse01ArrayWrapper(B_Handle); vtkm::cont::ArrayHandle G_Handle; Algorithm::ScanInclusiveByKey(C_Handle, InvB_Handle, G_Handle, vtkm::Add()); vtkm::cont::ArrayHandle H_Handle = ReverseScanInclusiveByKey(C_Handle, G_Handle, vtkm::Maximum()); vtkm::cont::ArrayHandle I_Handle; SegmentedSplitTransform sstWorklet; vtkm::worklet::DispatcherMapField segmentedSplitTransformDispatcher( sstWorklet); segmentedSplitTransformDispatcher.Invoke( B_Handle, D_Handle, F_Handle, G_Handle, H_Handle, I_Handle); return ScatterArrayWrapper(A_Handle, I_Handle); } template void RenumberRanksProcedure(vtkm::cont::ArrayHandle& A_Handle, vtkm::cont::ArrayHandle& B_Handle, vtkm::cont::ArrayHandle& C_Handle, vtkm::cont::ArrayHandle& D_Handle) { using Algorithm = vtkm::cont::Algorithm; vtkm::Id nPoints = A_Handle.GetNumberOfValues(); vtkm::cont::ArrayHandleCounting Ecouting_Handle(0, 1, nPoints); vtkm::cont::ArrayHandle E_Handle; Algorithm::Copy(Ecouting_Handle, E_Handle); vtkm::cont::ArrayHandle F_Handle; Algorithm::ScanInclusiveByKey(D_Handle, E_Handle, F_Handle, vtkm::Minimum()); vtkm::cont::ArrayHandle G_Handle; G_Handle = ArrayAddWrapper(A_Handle, F_Handle); vtkm::cont::ArrayHandleConstant HConstant_Handle(1, nPoints); vtkm::cont::ArrayHandle H_Handle; Algorithm::Copy(HConstant_Handle, H_Handle); vtkm::cont::ArrayHandle I_Handle; T initValue = 0; Algorithm::ScanExclusiveByKey(C_Handle, H_Handle, I_Handle, initValue, vtkm::Add()); vtkm::cont::ArrayHandle J_Handle; J_Handle = ScatterArrayWrapper(I_Handle, G_Handle); vtkm::cont::ArrayHandle K_Handle; K_Handle = ScatterArrayWrapper(B_Handle, G_Handle); vtkm::cont::ArrayHandle L_Handle; L_Handle = SegmentedSplitProcedure(J_Handle, K_Handle, D_Handle); vtkm::cont::ArrayHandle M_Handle; Algorithm::ScanInclusiveByKey(C_Handle, E_Handle, M_Handle, vtkm::Minimum()); vtkm::cont::ArrayHandle N_Handle; N_Handle = ArrayAddWrapper(L_Handle, M_Handle); A_Handle = ScatterArrayWrapper(I_Handle, N_Handle); } /////////////3D construction ///////////////////// /// \brief Segmented split for 3D x, y, z coordinates /// /// Split \c pointId_Handle, \c X_Handle, \c Y_Handle and \c Z_Handle within each segment /// as indicated by \c segId_Handle according to flags in \c flag_Handle. /// /// \tparam T /// \param pointId_Handle /// \param flag_Handle /// \param segId_Handle /// \param X_Handle /// \param Y_Handle /// \param Z_Handle template void SegmentedSplitProcedure3D(vtkm::cont::ArrayHandle& pointId_Handle, vtkm::cont::ArrayHandle& flag_Handle, vtkm::cont::ArrayHandle& segId_Handle, vtkm::cont::ArrayHandle& X_Handle, vtkm::cont::ArrayHandle& Y_Handle, vtkm::cont::ArrayHandle& Z_Handle) { using Algorithm = vtkm::cont::Algorithm; vtkm::cont::ArrayHandle D_Handle; T initValue = 0; Algorithm::ScanExclusiveByKey(segId_Handle, flag_Handle, D_Handle, initValue, vtkm::Add()); vtkm::cont::ArrayHandleCounting Ecouting_Handle(0, 1, pointId_Handle.GetNumberOfValues()); vtkm::cont::ArrayHandle E_Handle; Algorithm::Copy(Ecouting_Handle, E_Handle); vtkm::cont::ArrayHandle F_Handle; Algorithm::ScanInclusiveByKey(segId_Handle, E_Handle, F_Handle, vtkm::Minimum()); vtkm::cont::ArrayHandle InvB_Handle = Inverse01ArrayWrapper(flag_Handle); vtkm::cont::ArrayHandle G_Handle; Algorithm::ScanInclusiveByKey(segId_Handle, InvB_Handle, G_Handle, vtkm::Add()); vtkm::cont::ArrayHandle H_Handle = ReverseScanInclusiveByKey(segId_Handle, G_Handle, vtkm::Maximum()); vtkm::cont::ArrayHandle I_Handle; SegmentedSplitTransform sstWorklet; vtkm::worklet::DispatcherMapField segmentedSplitTransformDispatcher( sstWorklet); segmentedSplitTransformDispatcher.Invoke( flag_Handle, D_Handle, F_Handle, G_Handle, H_Handle, I_Handle); pointId_Handle = ScatterArrayWrapper(pointId_Handle, I_Handle); flag_Handle = ScatterArrayWrapper(flag_Handle, I_Handle); X_Handle = ScatterArrayWrapper(X_Handle, I_Handle); Y_Handle = ScatterArrayWrapper(Y_Handle, I_Handle); Z_Handle = ScatterArrayWrapper(Z_Handle, I_Handle); } /// \brief Perform one level of KD-Tree construction /// /// Construct a level of KD-Tree by segemeted splits (partitioning) of \c pointId_Handle, /// \c xrank_Handle, \c yrank_Handle and \c zrank_Handle according to the medium element /// in each segment as indicated by \c segId_Handle alone the axis determined by \c level. /// The split point of each segment will be updated in \c splitId_Handle. template void OneLevelSplit3D(vtkm::cont::ArrayHandle& pointId_Handle, vtkm::cont::ArrayHandle& xrank_Handle, vtkm::cont::ArrayHandle& yrank_Handle, vtkm::cont::ArrayHandle& zrank_Handle, vtkm::cont::ArrayHandle& segId_Handle, vtkm::cont::ArrayHandle& splitId_Handle, vtkm::Int32 level) { using Algorithm = vtkm::cont::Algorithm; vtkm::cont::ArrayHandle flag_Handle; if (level % 3 == 0) { flag_Handle = ComputeFlagProcedure(xrank_Handle, segId_Handle); } else if (level % 3 == 1) { flag_Handle = ComputeFlagProcedure(yrank_Handle, segId_Handle); } else { flag_Handle = ComputeFlagProcedure(zrank_Handle, segId_Handle); } SegmentedSplitProcedure3D( pointId_Handle, flag_Handle, segId_Handle, xrank_Handle, yrank_Handle, zrank_Handle); vtkm::cont::ArrayHandle segIdOld_Handle; Algorithm::Copy(segId_Handle, segIdOld_Handle); segId_Handle = NewKeyWrapper(segIdOld_Handle, flag_Handle); RenumberRanksProcedure(xrank_Handle, flag_Handle, segId_Handle, segIdOld_Handle); RenumberRanksProcedure(yrank_Handle, flag_Handle, segId_Handle, segIdOld_Handle); RenumberRanksProcedure(zrank_Handle, flag_Handle, segId_Handle, segIdOld_Handle); if (level % 3 == 0) { splitId_Handle = SaveSplitPointIdWrapper(pointId_Handle, flag_Handle, xrank_Handle, splitId_Handle); } else if (level % 3 == 1) { splitId_Handle = SaveSplitPointIdWrapper(pointId_Handle, flag_Handle, yrank_Handle, splitId_Handle); } else { splitId_Handle = SaveSplitPointIdWrapper(pointId_Handle, flag_Handle, zrank_Handle, splitId_Handle); } } /// \brief Construct KdTree from x y z coordinate vector. /// /// This method constructs an array based KD-Tree from x, y, z coordinates of points in \c /// coordi_Handle. The method rotates between x, y and z axis and splits input points into /// equal halves with respect to the split axis at each level of construction. The indices to /// the leaf nodes are returned in \c pointId_Handle and indices to internal nodes (splits) /// are returned in splitId_handle. /// /// \param coordi_Handle (in) x, y, z coordinates of input points /// \param pointId_Handle (out) returns indices to leaf nodes of the KD-tree /// \param splitId_Handle (out) returns indices to internal nodes of the KD-tree // Leaf Node vector and internal node (split) vectpr template void Run(const vtkm::cont::ArrayHandle, CoordStorageTag>& coordi_Handle, vtkm::cont::ArrayHandle& pointId_Handle, vtkm::cont::ArrayHandle& splitId_Handle) { using Algorithm = vtkm::cont::Algorithm; vtkm::Id nTrainingPoints = coordi_Handle.GetNumberOfValues(); vtkm::cont::ArrayHandleCounting counting_Handle(0, 1, nTrainingPoints); Algorithm::Copy(counting_Handle, pointId_Handle); vtkm::cont::ArrayHandle xorder_Handle; Algorithm::Copy(counting_Handle, xorder_Handle); vtkm::cont::ArrayHandle yorder_Handle; Algorithm::Copy(counting_Handle, yorder_Handle); vtkm::cont::ArrayHandle zorder_Handle; Algorithm::Copy(counting_Handle, zorder_Handle); splitId_Handle.Allocate(nTrainingPoints); vtkm::cont::ArrayHandle xcoordi_Handle; vtkm::cont::ArrayHandle ycoordi_Handle; vtkm::cont::ArrayHandle zcoordi_Handle; SeprateVec3AryHandle sepVec3Worklet; vtkm::worklet::DispatcherMapField sepVec3Dispatcher(sepVec3Worklet); sepVec3Dispatcher.Invoke(coordi_Handle, xcoordi_Handle, ycoordi_Handle, zcoordi_Handle); Algorithm::SortByKey(xcoordi_Handle, xorder_Handle); vtkm::cont::ArrayHandle xrank_Handle = ScatterArrayWrapper(pointId_Handle, xorder_Handle); Algorithm::SortByKey(ycoordi_Handle, yorder_Handle); vtkm::cont::ArrayHandle yrank_Handle = ScatterArrayWrapper(pointId_Handle, yorder_Handle); Algorithm::SortByKey(zcoordi_Handle, zorder_Handle); vtkm::cont::ArrayHandle zrank_Handle = ScatterArrayWrapper(pointId_Handle, zorder_Handle); vtkm::cont::ArrayHandle segId_Handle; vtkm::cont::ArrayHandleConstant constHandle(0, nTrainingPoints); Algorithm::Copy(constHandle, segId_Handle); ///// build kd tree ///// vtkm::Int32 maxLevel = static_cast(ceil(vtkm::Log2(nTrainingPoints) + 1)); for (vtkm::Int32 i = 0; i < maxLevel - 1; i++) { OneLevelSplit3D( pointId_Handle, xrank_Handle, yrank_Handle, zrank_Handle, segId_Handle, splitId_Handle, i); } } }; } } } // namespace vtkm::worklet #endif // vtk_m_worklet_KdTree3DConstruction_h