diff --git a/vtkm/cont/MergePartitionedDataSet.cxx b/vtkm/cont/MergePartitionedDataSet.cxx index 896e7a6a4..34a7a91ad 100644 --- a/vtkm/cont/MergePartitionedDataSet.cxx +++ b/vtkm/cont/MergePartitionedDataSet.cxx @@ -10,105 +10,378 @@ #include #include +#include #include +#include #include #include +#include #include #include -#include + +#include +#include + +namespace +{ + +void CountPointsAndCells(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::Id& numPointsTotal, + vtkm::Id& numCellsTotal) +{ + numPointsTotal = 0; + numCellsTotal = 0; + + for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); + ++partitionId) + { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + numPointsTotal += partition.GetNumberOfPoints(); + numCellsTotal += partition.GetNumberOfCells(); + } +} + +struct PassCellShapesNumIndices : vtkm::worklet::WorkletVisitCellsWithPoints +{ + using ControlSignature = void(CellSetIn inputTopology, FieldOut shapes, FieldOut numIndices); + using ExecutionSignature = void(CellShape, PointCount, _2, _3); + + template + VTKM_EXEC void operator()(const CellShape& inShape, + vtkm::IdComponent inNumIndices, + vtkm::UInt8& outShape, + vtkm::IdComponent& outNumIndices) const + { + outShape = inShape.Id; + outNumIndices = inNumIndices; + } +}; + +void MergeShapes(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::Id numCellsTotal, + vtkm::cont::ArrayHandle& shapes, + vtkm::cont::ArrayHandle& numIndices) +{ + vtkm::cont::Invoker invoke; + + shapes.Allocate(numCellsTotal); + numIndices.Allocate(numCellsTotal); + + vtkm::Id cellStartIndex = 0; + for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); + ++partitionId) + { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + vtkm::Id numCellsPartition = partition.GetNumberOfCells(); + + auto shapesView = vtkm::cont::make_ArrayHandleView(shapes, cellStartIndex, numCellsPartition); + auto numIndicesView = + vtkm::cont::make_ArrayHandleView(numIndices, cellStartIndex, numCellsPartition); + + invoke(PassCellShapesNumIndices{}, partition.GetCellSet(), shapesView, numIndicesView); + + cellStartIndex += numCellsPartition; + } + VTKM_ASSERT(cellStartIndex == numCellsTotal); +} + +struct PassCellIndices : vtkm::worklet::WorkletVisitCellsWithPoints +{ + using ControlSignature = void(CellSetIn inputTopology, FieldOut pointIndices); + using ExecutionSignature = void(PointIndices, _2); + + vtkm::Id IndexOffset; + + PassCellIndices(vtkm::Id indexOffset) + : IndexOffset(indexOffset) + { + } + + template + VTKM_EXEC void operator()(const InPointIndexType& inPoints, OutPointIndexType& outPoints) const + { + vtkm::IdComponent numPoints = inPoints.GetNumberOfComponents(); + VTKM_ASSERT(numPoints == outPoints.GetNumberOfComponents()); + for (vtkm::IdComponent pointIndex = 0; pointIndex < numPoints; pointIndex++) + { + outPoints[pointIndex] = inPoints[pointIndex] + this->IndexOffset; + } + } +}; + +void MergeIndices(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + const vtkm::cont::ArrayHandle& offsets, + vtkm::Id numIndicesTotal, + vtkm::cont::ArrayHandle& indices) +{ + vtkm::cont::Invoker invoke; + + indices.Allocate(numIndicesTotal); + + vtkm::Id pointStartIndex = 0; + vtkm::Id cellStartIndex = 0; + for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); + ++partitionId) + { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + vtkm::Id numCellsPartition = partition.GetNumberOfCells(); + + auto offsetsView = + vtkm::cont::make_ArrayHandleView(offsets, cellStartIndex, numCellsPartition + 1); + auto indicesGroupView = vtkm::cont::make_ArrayHandleGroupVecVariable(indices, offsetsView); + + invoke(PassCellIndices{ pointStartIndex }, partition.GetCellSet(), indicesGroupView); + + pointStartIndex += partition.GetNumberOfPoints(); + cellStartIndex += numCellsPartition; + } + VTKM_ASSERT(cellStartIndex == (offsets.GetNumberOfValues() - 1)); +} + +vtkm::cont::CellSetExplicit<> MergeCellSets( + const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::Id numPointsTotal, + vtkm::Id numCellsTotal) +{ + vtkm::cont::ArrayHandle shapes; + vtkm::cont::ArrayHandle numIndices; + MergeShapes(partitionedDataSet, numCellsTotal, shapes, numIndices); + + vtkm::cont::ArrayHandle offsets; + vtkm::Id numIndicesTotal; + vtkm::cont::ConvertNumComponentsToOffsets(numIndices, offsets, numIndicesTotal); + numIndices.ReleaseResources(); + + vtkm::cont::ArrayHandle indices; + MergeIndices(partitionedDataSet, offsets, numIndicesTotal, indices); + + vtkm::cont::CellSetExplicit<> outCells; + outCells.Fill(numPointsTotal, shapes, indices, offsets); + return outCells; +} + +struct ClearPartitionWorklet : vtkm::worklet::WorkletMapField +{ + using ControlSignature = void(FieldIn indices, WholeArrayInOut array); + using ExecutionSignature = void(WorkIndex, _2); + + vtkm::Id IndexOffset; + + ClearPartitionWorklet(vtkm::Id indexOffset) + : IndexOffset(indexOffset) + { + } + + template + VTKM_EXEC void operator()(vtkm::Id index, OutPortalType& outPortal) const + { + // It's weird to get a value from a portal only to override it, but the expect type + // is weird (a variable-sized Vec), so this is the only practical way to set it. + auto outVec = outPortal.Get(index + this->IndexOffset); + for (vtkm::IdComponent comp = 0; comp < outVec.GetNumberOfComponents(); ++comp) + { + outVec[comp] = 0; + } + // Shouldn't really do anything. + outPortal.Set(index + this->IndexOffset, outVec); + } +}; + +template +void ClearPartition(OutArrayHandle& outArray, vtkm::Id startIndex, vtkm::Id numValues) +{ + vtkm::cont::Invoker invoke; + invoke(ClearPartitionWorklet{ startIndex }, vtkm::cont::ArrayHandleIndex(numValues), outArray); +} + +struct CopyPartitionWorklet : vtkm::worklet::WorkletMapField +{ + using ControlSignature = void(FieldIn sourceArray, WholeArrayInOut mergedArray); + using ExecutionSignature = void(WorkIndex, _1, _2); + + vtkm::Id IndexOffset; + + CopyPartitionWorklet(vtkm::Id indexOffset) + : IndexOffset(indexOffset) + { + } + + template + VTKM_EXEC void operator()(vtkm::Id index, const InVecType& inVec, OutPortalType& outPortal) const + { + // It's weird to get a value from a portal only to override it, but the expect type + // is weird (a variable-sized Vec), so this is the only practical way to set it. + auto outVec = outPortal.Get(index + this->IndexOffset); + VTKM_ASSERT(inVec.GetNumberOfComponents() == outVec.GetNumberOfComponents()); + for (vtkm::IdComponent comp = 0; comp < outVec.GetNumberOfComponents(); ++comp) + { + outVec[comp] = static_cast(inVec[comp]); + } + // Shouldn't really do anything. + outPortal.Set(index + this->IndexOffset, outVec); + } +}; + +template +void CopyPartition(const vtkm::cont::Field& inField, OutArrayHandle& outArray, vtkm::Id startIndex) +{ + vtkm::cont::Invoker invoke; + using ComponentType = typename OutArrayHandle::ValueType::ComponentType; + if (inField.GetData().IsBaseComponentType()) + { + invoke(CopyPartitionWorklet{ startIndex }, + inField.GetData().ExtractArrayFromComponents(), + outArray); + } + else + { + VTKM_LOG_S(vtkm::cont::LogLevel::Info, + "Discovered mismatched types for field " << inField.GetName() + << ". Requires extra copy."); + invoke(CopyPartitionWorklet{ startIndex }, + inField.GetDataAsDefaultFloat().ExtractArrayFromComponents(), + outArray); + } +} + +template +vtkm::cont::UnknownArrayHandle MergeArray(vtkm::Id numPartitions, + HasFieldFunctor&& hasField, + GetFieldFunctor&& getField, + vtkm::Id totalSize) +{ + vtkm::cont::UnknownArrayHandle mergedArray = getField(0).GetData().NewInstanceBasic(); + mergedArray.Allocate(totalSize); + + vtkm::Id startIndex = 0; + for (vtkm::Id partitionId = 0; partitionId < numPartitions; ++partitionId) + { + vtkm::Id partitionSize; + if (hasField(partitionId, partitionSize)) + { + vtkm::cont::Field sourceField = getField(partitionId); + mergedArray.CastAndCallWithExtractedArray( + [=](auto array) { CopyPartition(sourceField, array, startIndex); }); + } + else + { + mergedArray.CastAndCallWithExtractedArray( + [=](auto array) { ClearPartition(array, startIndex, partitionSize); }); + } + startIndex += partitionSize; + } + VTKM_ASSERT(startIndex == totalSize); + + return mergedArray; +} + +vtkm::cont::CoordinateSystem MergeCoordinateSystem( + const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::IdComponent coordId, + vtkm::Id numPointsTotal) +{ + std::string coordName = partitionedDataSet.GetPartition(0).GetCoordinateSystem(coordId).GetName(); + auto hasField = [&](vtkm::Id partitionId, vtkm::Id& partitionSize) -> bool { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + partitionSize = partition.GetNumberOfPoints(); + // Should partitions match coordinates on name or coordinate id? They both should match, but + // for now let's go by id and check the name. + if (partition.GetNumberOfCoordinateSystems() <= coordId) + { + VTKM_LOG_S(vtkm::cont::LogLevel::Warn, + "When merging partitions, partition " + << partitionId << " is missing coordinate system with index " << coordId); + return false; + } + if (partition.GetCoordinateSystem(coordId).GetName() != coordName) + { + VTKM_LOG_S(vtkm::cont::LogLevel::Warn, + "When merging partitions, partition " + << partitionId << " reported a coordinate system with name '" + << partition.GetCoordinateSystem(coordId).GetName() + << "' instead of expected name '" << coordName << "'"); + } + return true; + }; + auto getField = [&](vtkm::Id partitionId) -> vtkm::cont::Field { + return partitionedDataSet.GetPartition(partitionId).GetCoordinateSystem(coordId); + }; + vtkm::cont::UnknownArrayHandle mergedArray = + MergeArray(partitionedDataSet.GetNumberOfPartitions(), hasField, getField, numPointsTotal); + return vtkm::cont::CoordinateSystem{ coordName, mergedArray }; +} + +vtkm::cont::Field MergeField(const vtkm::cont::PartitionedDataSet& partitionedDataSet, + vtkm::IdComponent fieldId, + vtkm::Id numPointsTotal, + vtkm::Id numCellsTotal) +{ + vtkm::cont::Field referenceField = partitionedDataSet.GetPartition(0).GetField(fieldId); + vtkm::Id totalSize = 0; + switch (referenceField.GetAssociation()) + { + case vtkm::cont::Field::Association::POINTS: + totalSize = numPointsTotal; + break; + case vtkm::cont::Field::Association::CELL_SET: + totalSize = numCellsTotal; + break; + default: + VTKM_LOG_S(vtkm::cont::LogLevel::Info, + "Skipping merge of field '" << referenceField.GetName() + << "' because it has an unsupported association."); + return referenceField; + } + + auto hasField = [&](vtkm::Id partitionId, vtkm::Id& partitionSize) -> bool { + vtkm::cont::DataSet partition = partitionedDataSet.GetPartition(partitionId); + if (partition.HasField(referenceField.GetName(), referenceField.GetAssociation())) + { + partitionSize = partition.GetField(referenceField.GetName(), referenceField.GetAssociation()) + .GetData() + .GetNumberOfValues(); + return true; + } + else + { + VTKM_LOG_S(vtkm::cont::LogLevel::Info, + "Partition " << partitionId << " does not have field " + << referenceField.GetName()); + switch (referenceField.GetAssociation()) + { + case vtkm::cont::Field::Association::POINTS: + partitionSize = partition.GetNumberOfPoints(); + break; + case vtkm::cont::Field::Association::CELL_SET: + partitionSize = partition.GetNumberOfCells(); + break; + default: + partitionSize = 0; + break; + } + return false; + } + }; + auto getField = [&](vtkm::Id partitionId) -> vtkm::cont::Field { + return partitionedDataSet.GetPartition(partitionId) + .GetField(referenceField.GetName(), referenceField.GetAssociation()); + }; + vtkm::cont::UnknownArrayHandle mergedArray = + MergeArray(partitionedDataSet.GetNumberOfPartitions(), hasField, getField, totalSize); + return vtkm::cont::Field{ referenceField.GetName(), + referenceField.GetAssociation(), + mergedArray }; +} + +} // anonymous namespace + +//----------------------------------------------------------------------------- namespace vtkm { namespace cont { -struct TransferCellsFunctor -{ - template - VTKM_CONT void operator()(const T& cellSetIn, - vtkm::cont::ArrayHandle& shapes, - vtkm::cont::ArrayHandle& numIndices, - vtkm::cont::ArrayHandle& connectivity, - vtkm::Id pointStartIndex) const - { - // allocate shapes and numIndices - vtkm::Id cellStartIndex = shapes.GetNumberOfValues(); - shapes.Allocate(cellStartIndex + cellSetIn.GetNumberOfCells(), vtkm::CopyFlag::On); - numIndices.Allocate(cellStartIndex + cellSetIn.GetNumberOfCells(), vtkm::CopyFlag::On); - - // fill the view of numIndices - vtkm::cont::ArrayHandleView> viewArrayNumIndices( - numIndices, cellStartIndex, cellSetIn.GetNumberOfCells()); - vtkm::cont::Invoker invoke; - invoke(vtkm::worklet::CellDeepCopy::CountCellPoints{}, cellSetIn, viewArrayNumIndices); - - // convert numIndices to offsets and derive numberOfConnectivity - vtkm::cont::ArrayHandle offsets; - vtkm::Id numberOfConnectivity; - vtkm::cont::ConvertNumComponentsToOffsets(viewArrayNumIndices, offsets, numberOfConnectivity); - - // allocate connectivity - vtkm::Id connectivityStartIndex = connectivity.GetNumberOfValues(); - connectivity.Allocate(connectivityStartIndex + numberOfConnectivity, vtkm::CopyFlag::On); - - // fill the view of shapes and connectivity - vtkm::cont::ArrayHandleView> viewArrayShapes( - shapes, cellStartIndex, cellSetIn.GetNumberOfCells()); - vtkm::cont::ArrayHandleView> viewArrayConnectivity( - connectivity, connectivityStartIndex, numberOfConnectivity); - invoke(vtkm::worklet::CellDeepCopy::PassCellStructure{}, - cellSetIn, - viewArrayShapes, - vtkm::cont::make_ArrayHandleGroupVecVariable(viewArrayConnectivity, offsets)); - shapes.ReleaseResourcesExecution(); - offsets.ReleaseResourcesExecution(); - connectivity.ReleaseResourcesExecution(); - - // point the connectivity to the point indices of this partition - vtkm::cont::Algorithm::Transform( - vtkm::cont::ArrayHandleConstant(pointStartIndex, numberOfConnectivity), - viewArrayConnectivity, - viewArrayConnectivity, - vtkm::Sum()); - } -}; - -void TransferCells(const vtkm::cont::UnknownCellSet& cellSetIn, - vtkm::cont::ArrayHandle& shapes, - vtkm::cont::ArrayHandle& numIndices, - vtkm::cont::ArrayHandle& connectivity, - vtkm::Id startIndex) -{ - vtkm::cont::CastAndCall( - cellSetIn, TransferCellsFunctor{}, shapes, numIndices, connectivity, startIndex); -} - -struct TransferArrayFunctor -{ - template - VTKM_CONT void operator()(const vtkm::cont::ArrayHandle& arrayIn, - vtkm::cont::UnknownArrayHandle& arrayOut, - vtkm::Id startIndex) const - { - vtkm::cont::ArrayHandleView> viewArrayOut( - arrayOut.AsArrayHandle>(), - startIndex, - arrayIn.GetNumberOfValues()); - vtkm::cont::ArrayCopy(arrayIn, viewArrayOut); - } -}; - -void TransferArray(const vtkm::cont::UnknownArrayHandle& arrayIn, - vtkm::cont::UnknownArrayHandle& arrayOut, - vtkm::Id startIndex) -{ - arrayIn.CastAndCallForTypes< - VTKM_DEFAULT_TYPE_LIST, - vtkm::List>( - TransferArrayFunctor{}, arrayOut, startIndex); -} - -//----------------------------------------------------------------------------- VTKM_CONT vtkm::cont::DataSet MergePartitionedDataSet( const vtkm::cont::PartitionedDataSet& partitionedDataSet) @@ -116,103 +389,27 @@ vtkm::cont::DataSet MergePartitionedDataSet( // verify correctnees of data VTKM_ASSERT(partitionedDataSet.GetNumberOfPartitions() > 0); - vtkm::cont::UnknownArrayHandle coordsOut; - vtkm::cont::ArrayCopy( - partitionedDataSet.GetPartition(0).GetCoordinateSystem().GetDataAsMultiplexer(), coordsOut); - vtkm::cont::ArrayHandle shapes; - vtkm::cont::ArrayHandle numIndices; - vtkm::cont::ArrayHandle connectivity; - vtkm::Id numberOfPointsSoFar = 0; - for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); - partitionId++) + vtkm::Id numPointsTotal; + vtkm::Id numCellsTotal; + CountPointsAndCells(partitionedDataSet, numPointsTotal, numCellsTotal); + + vtkm::cont::DataSet outputData; + outputData.SetCellSet(MergeCellSets(partitionedDataSet, numPointsTotal, numCellsTotal)); + + vtkm::cont::DataSet partition0 = partitionedDataSet.GetPartition(0); + for (vtkm::IdComponent coordId = 0; coordId < partition0.GetNumberOfCoordinateSystems(); + ++coordId) { - auto partition = partitionedDataSet.GetPartition(partitionId); - - // Transfer points - auto coordsIn = partition.GetCoordinateSystem().GetDataAsMultiplexer(); - coordsOut.Allocate(numberOfPointsSoFar + partition.GetNumberOfPoints(), vtkm::CopyFlag::On); - TransferArray(coordsIn, coordsOut, numberOfPointsSoFar); - - // Transfer cells - vtkm::cont::UnknownCellSet cellset; - cellset = partition.GetCellSet(); - TransferCells(cellset, shapes, numIndices, connectivity, numberOfPointsSoFar); - - numberOfPointsSoFar += partition.GetNumberOfPoints(); + outputData.AddCoordinateSystem( + MergeCoordinateSystem(partitionedDataSet, coordId, numPointsTotal)); } - // create dataset - vtkm::cont::CellSetExplicit<> cellSet; - vtkm::Id nPts = static_cast(coordsOut.GetNumberOfValues()); - vtkm::cont::ArrayHandle offsets; - vtkm::cont::Algorithm::ScanExtended(numIndices, offsets); - cellSet.Fill(nPts, shapes, connectivity, offsets); - vtkm::cont::DataSet derivedDataSet; - derivedDataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem( - partitionedDataSet.GetPartition(0).GetCoordinateSystem().GetName(), coordsOut)); - derivedDataSet.SetCellSet(cellSet); - - // Transfer fields - for (vtkm::IdComponent f = 0; f < partitionedDataSet.GetPartition(0).GetNumberOfFields(); f++) + for (vtkm::IdComponent fieldId = 0; fieldId < partition0.GetNumberOfFields(); ++fieldId) { - std::string name = partitionedDataSet.GetPartition(0).GetField(f).GetName(); - vtkm::cont::UnknownArrayHandle outFieldHandle; - vtkm::cont::ArrayCopy(partitionedDataSet.GetPartition(0).GetField(name).GetData(), - outFieldHandle); - - if (partitionedDataSet.GetPartition(0).GetField(name).IsFieldCell()) - { - outFieldHandle.Allocate(derivedDataSet.GetNumberOfCells()); - vtkm::Id numberOfCellValuesSoFar = 0; - for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); - partitionId++) - { - try - { - auto cellField = partitionedDataSet.GetPartition(partitionId).GetField(name).GetData(); - TransferArray(cellField, outFieldHandle, numberOfCellValuesSoFar); - } - catch (const vtkm::cont::Error& error) - { - std::cout << "Partition 0 contains an array that partition " << partitionId - << " does not contain. The merged Dataset will have random values where values " - "were missing." - << std::endl; - std::cout << error.GetMessage() << std::endl; - } - numberOfCellValuesSoFar += partitionedDataSet.GetPartition(partitionId).GetNumberOfCells(); - } - derivedDataSet.AddCellField(name, outFieldHandle); - } - else - { - outFieldHandle.Allocate(derivedDataSet.GetNumberOfPoints()); - vtkm::Id numberOfPointValuesSoFar = 0; - for (vtkm::Id partitionId = 0; partitionId < partitionedDataSet.GetNumberOfPartitions(); - partitionId++) - { - try - { - auto pointField = partitionedDataSet.GetPartition(partitionId).GetField(name).GetData(); - TransferArray(pointField, outFieldHandle, numberOfPointValuesSoFar); - } - // catch (vtkm::cont::ErrorBadValue& error) - catch (const vtkm::cont::Error& error) - { - std::cout << "Partition 0 contains an array that partition " << partitionId - << " does not contain. The merged Dataset will have random values where values " - "were missing." - << std::endl; - std::cout << error.GetMessage() << std::endl; - } - numberOfPointValuesSoFar += - partitionedDataSet.GetPartition(partitionId).GetNumberOfPoints(); - } - derivedDataSet.AddPointField(name, outFieldHandle); - } + outputData.AddField(MergeField(partitionedDataSet, fieldId, numPointsTotal, numCellsTotal)); } - return derivedDataSet; + return outputData; } }