From 10f21b21ae2945817943646c47aca7e8affb953f Mon Sep 17 00:00:00 2001 From: Kenneth Moreland Date: Thu, 6 Jan 2022 07:15:38 -0700 Subject: [PATCH] Pre-allocate arrays for MergePartitionedDataSet The initial implementation of `MergePartitionedDataSet` would grow each array as it was generated. As each partition was revisited, the arrays being merged would be reallocated and data appended to the end. Although this works, it is slower than necessary. Each reallocation has to copy the previously saved data into the newly allocated memory space. This new implementation first counts how big each array should be and then copies data from each partition into the appropriate location of each dataset. Also changed the templating of how fields are copied so that all field types are supported, not just those in the common types. --- vtkm/cont/MergePartitionedDataSet.cxx | 555 +++++++++++++++++--------- 1 file changed, 376 insertions(+), 179 deletions(-) diff --git a/vtkm/cont/MergePartitionedDataSet.cxx b/vtkm/cont/MergePartitionedDataSet.cxx index 896e7a6a4..34a7a91ad 100644 --- a/vtkm/cont/MergePartitionedDataSet.cxx +++ b/vtkm/cont/MergePartitionedDataSet.cxx @@ -10,105 +10,378 @@ #include #include +#include #include +#include #include #include +#include #include #include -#include + +#include +#include + +namespace +{ + +void CountPointsAndCells(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::Id& numPointsTotal, + vtkm::Id& numCellsTotal) +{ + numPointsTotal = 0; + numCellsTotal = 0; + + for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); + ++partitionId) + { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + numPointsTotal += partition.GetNumberOfPoints(); + numCellsTotal += partition.GetNumberOfCells(); + } +} + +struct PassCellShapesNumIndices : vtkm::worklet::WorkletVisitCellsWithPoints +{ + using ControlSignature = void(CellSetIn inputTopology, FieldOut shapes, FieldOut numIndices); 
+ using ExecutionSignature = void(CellShape, PointCount, _2, _3); + + template + VTKM_EXEC void operator()(const CellShape& inShape, + vtkm::IdComponent inNumIndices, + vtkm::UInt8& outShape, + vtkm::IdComponent& outNumIndices) const + { + outShape = inShape.Id; + outNumIndices = inNumIndices; + } +}; + +void MergeShapes(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::Id numCellsTotal, + vtkm::cont::ArrayHandle& shapes, + vtkm::cont::ArrayHandle& numIndices) +{ + vtkm::cont::Invoker invoke; + + shapes.Allocate(numCellsTotal); + numIndices.Allocate(numCellsTotal); + + vtkm::Id cellStartIndex = 0; + for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); + ++partitionId) + { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + vtkm::Id numCellsPartition = partition.GetNumberOfCells(); + + auto shapesView = vtkm::cont::make_ArrayHandleView(shapes, cellStartIndex, numCellsPartition); + auto numIndicesView = + vtkm::cont::make_ArrayHandleView(numIndices, cellStartIndex, numCellsPartition); + + invoke(PassCellShapesNumIndices{}, partition.GetCellSet(), shapesView, numIndicesView); + + cellStartIndex += numCellsPartition; + } + VTKM_ASSERT(cellStartIndex == numCellsTotal); +} + +struct PassCellIndices : vtkm::worklet::WorkletVisitCellsWithPoints +{ + using ControlSignature = void(CellSetIn inputTopology, FieldOut pointIndices); + using ExecutionSignature = void(PointIndices, _2); + + vtkm::Id IndexOffset; + + PassCellIndices(vtkm::Id indexOffset) + : IndexOffset(indexOffset) + { + } + + template + VTKM_EXEC void operator()(const InPointIndexType& inPoints, OutPointIndexType& outPoints) const + { + vtkm::IdComponent numPoints = inPoints.GetNumberOfComponents(); + VTKM_ASSERT(numPoints == outPoints.GetNumberOfComponents()); + for (vtkm::IdComponent pointIndex = 0; pointIndex < numPoints; pointIndex++) + { + outPoints[pointIndex] = inPoints[pointIndex] + this->IndexOffset; + } + } +}; + +void 
MergeIndices(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + const vtkm::cont::ArrayHandle& offsets, + vtkm::Id numIndicesTotal, + vtkm::cont::ArrayHandle& indices) +{ + vtkm::cont::Invoker invoke; + + indices.Allocate(numIndicesTotal); + + vtkm::Id pointStartIndex = 0; + vtkm::Id cellStartIndex = 0; + for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); + ++partitionId) + { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + vtkm::Id numCellsPartition = partition.GetNumberOfCells(); + + auto offsetsView = + vtkm::cont::make_ArrayHandleView(offsets, cellStartIndex, numCellsPartition + 1); + auto indicesGroupView = vtkm::cont::make_ArrayHandleGroupVecVariable(indices, offsetsView); + + invoke(PassCellIndices{ pointStartIndex }, partition.GetCellSet(), indicesGroupView); + + pointStartIndex += partition.GetNumberOfPoints(); + cellStartIndex += numCellsPartition; + } + VTKM_ASSERT(cellStartIndex == (offsets.GetNumberOfValues() - 1)); +} + +vtkm::cont::CellSetExplicit<> MergeCellSets( + const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::Id numPointsTotal, + vtkm::Id numCellsTotal) +{ + vtkm::cont::ArrayHandle shapes; + vtkm::cont::ArrayHandle numIndices; + MergeShapes(partitionedDataSet, numCellsTotal, shapes, numIndices); + + vtkm::cont::ArrayHandle offsets; + vtkm::Id numIndicesTotal; + vtkm::cont::ConvertNumComponentsToOffsets(numIndices, offsets, numIndicesTotal); + numIndices.ReleaseResources(); + + vtkm::cont::ArrayHandle indices; + MergeIndices(partitionedDataSet, offsets, numIndicesTotal, indices); + + vtkm::cont::CellSetExplicit<> outCells; + outCells.Fill(numPointsTotal, shapes, indices, offsets); + return outCells; +} + +struct ClearPartitionWorklet : vtkm::worklet::WorkletMapField +{ + using ControlSignature = void(FieldIn indices, WholeArrayInOut array); + using ExecutionSignature = void(WorkIndex, _2); + + vtkm::Id IndexOffset; + + ClearPartitionWorklet(vtkm::Id 
indexOffset) + : IndexOffset(indexOffset) + { + } + + template + VTKM_EXEC void operator()(vtkm::Id index, OutPortalType& outPortal) const + { + // It's weird to get a value from a portal only to overwrite it, but the expected type + // is weird (a variable-sized Vec), so this is the only practical way to set it. + auto outVec = outPortal.Get(index + this->IndexOffset); + for (vtkm::IdComponent comp = 0; comp < outVec.GetNumberOfComponents(); ++comp) + { + outVec[comp] = 0; + } + // Shouldn't really do anything. + outPortal.Set(index + this->IndexOffset, outVec); + } +}; + +template +void ClearPartition(OutArrayHandle& outArray, vtkm::Id startIndex, vtkm::Id numValues) +{ + vtkm::cont::Invoker invoke; + invoke(ClearPartitionWorklet{ startIndex }, vtkm::cont::ArrayHandleIndex(numValues), outArray); +} + +struct CopyPartitionWorklet : vtkm::worklet::WorkletMapField +{ + using ControlSignature = void(FieldIn sourceArray, WholeArrayInOut mergedArray); + using ExecutionSignature = void(WorkIndex, _1, _2); + + vtkm::Id IndexOffset; + + CopyPartitionWorklet(vtkm::Id indexOffset) + : IndexOffset(indexOffset) + { + } + + template + VTKM_EXEC void operator()(vtkm::Id index, const InVecType& inVec, OutPortalType& outPortal) const + { + // It's weird to get a value from a portal only to overwrite it, but the expected type + // is weird (a variable-sized Vec), so this is the only practical way to set it. + auto outVec = outPortal.Get(index + this->IndexOffset); + VTKM_ASSERT(inVec.GetNumberOfComponents() == outVec.GetNumberOfComponents()); + for (vtkm::IdComponent comp = 0; comp < outVec.GetNumberOfComponents(); ++comp) + { + outVec[comp] = static_cast(inVec[comp]); + } + // Shouldn't really do anything. 
+ outPortal.Set(index + this->IndexOffset, outVec); + } +}; + +template +void CopyPartition(const vtkm::cont::Field& inField, OutArrayHandle& outArray, vtkm::Id startIndex) +{ + vtkm::cont::Invoker invoke; + using ComponentType = typename OutArrayHandle::ValueType::ComponentType; + if (inField.GetData().IsBaseComponentType()) + { + invoke(CopyPartitionWorklet{ startIndex }, + inField.GetData().ExtractArrayFromComponents(), + outArray); + } + else + { + VTKM_LOG_S(vtkm::cont::LogLevel::Info, + "Discovered mismatched types for field " << inField.GetName() + << ". Requires extra copy."); + invoke(CopyPartitionWorklet{ startIndex }, + inField.GetDataAsDefaultFloat().ExtractArrayFromComponents(), + outArray); + } +} + +template +vtkm::cont::UnknownArrayHandle MergeArray(vtkm::Id numPartitions, + HasFieldFunctor&& hasField, + GetFieldFunctor&& getField, + vtkm::Id totalSize) +{ + vtkm::cont::UnknownArrayHandle mergedArray = getField(0).GetData().NewInstanceBasic(); + mergedArray.Allocate(totalSize); + + vtkm::Id startIndex = 0; + for (vtkm::Id partitionId = 0; partitionId < numPartitions; ++partitionId) + { + vtkm::Id partitionSize; + if (hasField(partitionId, partitionSize)) + { + vtkm::cont::Field sourceField = getField(partitionId); + mergedArray.CastAndCallWithExtractedArray( + [=](auto array) { CopyPartition(sourceField, array, startIndex); }); + } + else + { + mergedArray.CastAndCallWithExtractedArray( + [=](auto array) { ClearPartition(array, startIndex, partitionSize); }); + } + startIndex += partitionSize; + } + VTKM_ASSERT(startIndex == totalSize); + + return mergedArray; +} + +vtkm::cont::CoordinateSystem MergeCoordinateSystem( + const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::IdComponent coordId, + vtkm::Id numPointsTotal) +{ + std::string coordName = partitionedDataSet.GetPartition(0).GetCoordinateSystem(coordId).GetName(); + auto hasField = [&](vtkm::Id partitionId, vtkm::Id& partitionSize) -> bool { + vtkm::cont::DataSet partition = 
partitionedDataSet.GetPartition(partitionId); + partitionSize = partition.GetNumberOfPoints(); + // Should partitions match coordinates on name or coordinate id? They both should match, but + // for now let's go by id and check the name. + if (partition.GetNumberOfCoordinateSystems() <= coordId) + { + VTKM_LOG_S(vtkm::cont::LogLevel::Warn, + "When merging partitions, partition " + << partitionId << " is missing coordinate system with index " << coordId); + return false; + } + if (partition.GetCoordinateSystem(coordId).GetName() != coordName) + { + VTKM_LOG_S(vtkm::cont::LogLevel::Warn, + "When merging partitions, partition " + << partitionId << " reported a coordinate system with name '" + << partition.GetCoordinateSystem(coordId).GetName() + << "' instead of expected name '" << coordName << "'"); + } + return true; + }; + auto getField = [&](vtkm::Id partitionId) -> vtkm::cont::Field { + return partitionedDataSet.GetPartition(partitionId).GetCoordinateSystem(coordId); + }; + vtkm::cont::UnknownArrayHandle mergedArray = + MergeArray(partitionedDataSet.GetNumberOfPartitions(), hasField, getField, numPointsTotal); + return vtkm::cont::CoordinateSystem{ coordName, mergedArray }; +} + +vtkm::cont::Field MergeField(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::IdComponent fieldId, + vtkm::Id numPointsTotal, + vtkm::Id numCellsTotal) +{ + vtkm::cont::Field referenceField = partitionedDataSet.GetPartition(0).GetField(fieldId); + vtkm::Id totalSize = 0; + switch (referenceField.GetAssociation()) + { + case vtkm::cont::Field::Association::POINTS: + totalSize = numPointsTotal; + break; + case vtkm::cont::Field::Association::CELL_SET: + totalSize = numCellsTotal; + break; + default: + VTKM_LOG_S(vtkm::cont::LogLevel::Info, + "Skipping merge of field '" << referenceField.GetName() + << "' because it has an unsupported association."); + return referenceField; + } + + auto hasField = [&](vtkm::Id partitionId, vtkm::Id& partitionSize) -> bool { + 
vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + if (partition.HasField(referenceField.GetName(), referenceField.GetAssociation())) + { + partitionSize = partition.GetField(referenceField.GetName(), referenceField.GetAssociation()) + .GetData() + .GetNumberOfValues(); + return true; + } + else + { + VTKM_LOG_S(vtkm::cont::LogLevel::Info, + "Partition " << partitionId << " does not have field " + << referenceField.GetName()); + switch (referenceField.GetAssociation()) + { + case vtkm::cont::Field::Association::POINTS: + partitionSize = partition.GetNumberOfPoints(); + break; + case vtkm::cont::Field::Association::CELL_SET: + partitionSize = partition.GetNumberOfCells(); + break; + default: + partitionSize = 0; + break; + } + return false; + } + }; + auto getField = [&](vtkm::Id partitionId) -> vtkm::cont::Field { + return partitionedDataSet.GetPartition(partitionId) + .GetField(referenceField.GetName(), referenceField.GetAssociation()); + }; + vtkm::cont::UnknownArrayHandle mergedArray = + MergeArray(partitionedDataSet.GetNumberOfPartitions(), hasField, getField, totalSize); + return vtkm::cont::Field{ referenceField.GetName(), + referenceField.GetAssociation(), + mergedArray }; +} + +} // anonymous namespace + +//----------------------------------------------------------------------------- namespace vtkm { namespace cont { -struct TransferCellsFunctor -{ - template - VTKM_CONT void operator()(const T& cellSetIn, - vtkm::cont::ArrayHandle& shapes, - vtkm::cont::ArrayHandle& numIndices, - vtkm::cont::ArrayHandle& connectivity, - vtkm::Id pointStartIndex) const - { - // allocate shapes and numIndices - vtkm::Id cellStartIndex = shapes.GetNumberOfValues(); - shapes.Allocate(cellStartIndex + cellSetIn.GetNumberOfCells(), vtkm::CopyFlag::On); - numIndices.Allocate(cellStartIndex + cellSetIn.GetNumberOfCells(), vtkm::CopyFlag::On); - - // fill the view of numIndices - vtkm::cont::ArrayHandleView> viewArrayNumIndices( - numIndices, 
cellStartIndex, cellSetIn.GetNumberOfCells()); - vtkm::cont::Invoker invoke; - invoke(vtkm::worklet::CellDeepCopy::CountCellPoints{}, cellSetIn, viewArrayNumIndices); - - // convert numIndices to offsets and derive numberOfConnectivity - vtkm::cont::ArrayHandle offsets; - vtkm::Id numberOfConnectivity; - vtkm::cont::ConvertNumComponentsToOffsets(viewArrayNumIndices, offsets, numberOfConnectivity); - - // allocate connectivity - vtkm::Id connectivityStartIndex = connectivity.GetNumberOfValues(); - connectivity.Allocate(connectivityStartIndex + numberOfConnectivity, vtkm::CopyFlag::On); - - // fill the view of shapes and connectivity - vtkm::cont::ArrayHandleView> viewArrayShapes( - shapes, cellStartIndex, cellSetIn.GetNumberOfCells()); - vtkm::cont::ArrayHandleView> viewArrayConnectivity( - connectivity, connectivityStartIndex, numberOfConnectivity); - invoke(vtkm::worklet::CellDeepCopy::PassCellStructure{}, - cellSetIn, - viewArrayShapes, - vtkm::cont::make_ArrayHandleGroupVecVariable(viewArrayConnectivity, offsets)); - shapes.ReleaseResourcesExecution(); - offsets.ReleaseResourcesExecution(); - connectivity.ReleaseResourcesExecution(); - - // point the connectivity to the point indices of this partition - vtkm::cont::Algorithm::Transform( - vtkm::cont::ArrayHandleConstant(pointStartIndex, numberOfConnectivity), - viewArrayConnectivity, - viewArrayConnectivity, - vtkm::Sum()); - } -}; - -void TransferCells(const vtkm::cont::UnknownCellSet& cellSetIn, - vtkm::cont::ArrayHandle& shapes, - vtkm::cont::ArrayHandle& numIndices, - vtkm::cont::ArrayHandle& connectivity, - vtkm::Id startIndex) -{ - vtkm::cont::CastAndCall( - cellSetIn, TransferCellsFunctor{}, shapes, numIndices, connectivity, startIndex); -} - -struct TransferArrayFunctor -{ - template - VTKM_CONT void operator()(const vtkm::cont::ArrayHandle& arrayIn, - vtkm::cont::UnknownArrayHandle& arrayOut, - vtkm::Id startIndex) const - { - vtkm::cont::ArrayHandleView> viewArrayOut( - arrayOut.AsArrayHandle>(), - 
startIndex, - arrayIn.GetNumberOfValues()); - vtkm::cont::ArrayCopy(arrayIn, viewArrayOut); - } -}; - -void TransferArray(const vtkm::cont::UnknownArrayHandle& arrayIn, - vtkm::cont::UnknownArrayHandle& arrayOut, - vtkm::Id startIndex) -{ - arrayIn.CastAndCallForTypes< - VTKM_DEFAULT_TYPE_LIST, - vtkm::List>( - TransferArrayFunctor{}, arrayOut, startIndex); -} - -//----------------------------------------------------------------------------- VTKM_CONT vtkm::cont::DataSet MergePartitionedDataSet( const vtkm::cont::PartitionedDataSet& partitionedDataSet) @@ -116,103 +389,27 @@ vtkm::cont::DataSet MergePartitionedDataSet( // verify correctnees of data VTKM_ASSERT(partitionedDataSet.GetNumberOfPartitions() > 0); - vtkm::cont::UnknownArrayHandle coordsOut; - vtkm::cont::ArrayCopy( - partitionedDataSet.GetPartition(0).GetCoordinateSystem().GetDataAsMultiplexer(), coordsOut); - vtkm::cont::ArrayHandle shapes; - vtkm::cont::ArrayHandle numIndices; - vtkm::cont::ArrayHandle connectivity; - vtkm::Id numberOfPointsSoFar = 0; - for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); - partitionId++) + vtkm::Id numPointsTotal; + vtkm::Id numCellsTotal; + CountPointsAndCells(partitionedDataSet, numPointsTotal, numCellsTotal); + + vtkm::cont::DataSet outputData; + outputData.SetCellSet(MergeCellSets(partitionedDataSet, numPointsTotal, numCellsTotal)); + + vtkm::cont::DataSet partition0 = partitionedDataSet.GetPartition(0); + for (vtkm::IdComponent coordId = 0; coordId < partition0.GetNumberOfCoordinateSystems(); + ++coordId) { - auto partition = partitionedDataSet.GetPartition(partitionId); - - // Transfer points - auto coordsIn = partition.GetCoordinateSystem().GetDataAsMultiplexer(); - coordsOut.Allocate(numberOfPointsSoFar + partition.GetNumberOfPoints(), vtkm::CopyFlag::On); - TransferArray(coordsIn, coordsOut, numberOfPointsSoFar); - - // Transfer cells - vtkm::cont::UnknownCellSet cellset; - cellset = partition.GetCellSet(); - 
TransferCells(cellset, shapes, numIndices, connectivity, numberOfPointsSoFar); - - numberOfPointsSoFar += partition.GetNumberOfPoints(); + outputData.AddCoordinateSystem( + MergeCoordinateSystem(partitionedDataSet, coordId, numPointsTotal)); } - // create dataset - vtkm::cont::CellSetExplicit<> cellSet; - vtkm::Id nPts = static_cast(coordsOut.GetNumberOfValues()); - vtkm::cont::ArrayHandle offsets; - vtkm::cont::Algorithm::ScanExtended(numIndices, offsets); - cellSet.Fill(nPts, shapes, connectivity, offsets); - vtkm::cont::DataSet derivedDataSet; - derivedDataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem( - partitionedDataSet.GetPartition(0).GetCoordinateSystem().GetName(), coordsOut)); - derivedDataSet.SetCellSet(cellSet); - - // Transfer fields - for (vtkm::IdComponent f = 0; f < partitionedDataSet.GetPartition(0).GetNumberOfFields(); f++) + for (vtkm::IdComponent fieldId = 0; fieldId < partition0.GetNumberOfFields(); ++fieldId) { - std::string name = partitionedDataSet.GetPartition(0).GetField(f).GetName(); - vtkm::cont::UnknownArrayHandle outFieldHandle; - vtkm::cont::ArrayCopy(partitionedDataSet.GetPartition(0).GetField(name).GetData(), - outFieldHandle); - - if (partitionedDataSet.GetPartition(0).GetField(name).IsFieldCell()) - { - outFieldHandle.Allocate(derivedDataSet.GetNumberOfCells()); - vtkm::Id numberOfCellValuesSoFar = 0; - for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); - partitionId++) - { - try - { - auto cellField = partitionedDataSet.GetPartition(partitionId).GetField(name).GetData(); - TransferArray(cellField, outFieldHandle, numberOfCellValuesSoFar); - } - catch (const vtkm::cont::Error& error) - { - std::cout << "Partition 0 contains an array that partition " << partitionId - << " does not contain. The merged Dataset will have random values where values " - "were missing." 
- << std::endl; - std::cout << error.GetMessage() << std::endl; - } - numberOfCellValuesSoFar += partitionedDataSet.GetPartition(partitionId).GetNumberOfCells(); - } - derivedDataSet.AddCellField(name, outFieldHandle); - } - else - { - outFieldHandle.Allocate(derivedDataSet.GetNumberOfPoints()); - vtkm::Id numberOfPointValuesSoFar = 0; - for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); - partitionId++) - { - try - { - auto pointField = partitionedDataSet.GetPartition(partitionId).GetField(name).GetData(); - TransferArray(pointField, outFieldHandle, numberOfPointValuesSoFar); - } - // catch (vtkm::cont::ErrorBadValue& error) - catch (const vtkm::cont::Error& error) - { - std::cout << "Partition 0 contains an array that partition " << partitionId - << " does not contain. The merged Dataset will have random values where values " - "were missing." - << std::endl; - std::cout << error.GetMessage() << std::endl; - } - numberOfPointValuesSoFar += - partitionedDataSet.GetPartition(partitionId).GetNumberOfPoints(); - } - derivedDataSet.AddPointField(name, outFieldHandle); - } + outputData.AddField(MergeField(partitionedDataSet, fieldId, numPointsTotal, numCellsTotal)); } - return derivedDataSet; + return outputData; } }