add test and fix cuda long compiling issue

This commit is contained in:
zhe 2024-02-05 16:14:19 -05:00
parent fc92eafb76
commit 59c545c15c
2 changed files with 59 additions and 135 deletions

@ -26,23 +26,6 @@
namespace
{
template <typename T>
struct SetToInvalid : public vtkm::worklet::WorkletMapField
{
T InvalidValue;
SetToInvalid(T invalidValue)
: InvalidValue(invalidValue)
{
}
typedef void ControlSignature(FieldInOut);
typedef void ExecutionSignature(_1);
template <typename ValueType>
VTKM_EXEC void operator()(ValueType& value) const
{
value = this->InvalidValue;
}
};
struct CopyWithOffsetWorklet : public vtkm::worklet::WorkletMapField
{
vtkm::Id OffsetValue;
@ -271,121 +254,6 @@ vtkm::cont::CellSetExplicit<> MergeCellSetsExplicit(
return outCells;
}
struct ClearPartitionWorklet : vtkm::worklet::WorkletMapField
{
using ControlSignature = void(FieldIn indices, WholeArrayInOut array);
using ExecutionSignature = void(WorkIndex, _2);
vtkm::Id IndexOffset;
ClearPartitionWorklet(vtkm::Id indexOffset)
: IndexOffset(indexOffset)
{
}
template <typename OutPortalType>
VTKM_EXEC void operator()(vtkm::Id index, OutPortalType& outPortal) const
{
// It's weird to get a value from a portal only to override it, but the expect type
// is weird (a variable-sized Vec), so this is the only practical way to set it.
auto outVec = outPortal.Get(index + this->IndexOffset);
for (vtkm::IdComponent comp = 0; comp < outVec.GetNumberOfComponents(); ++comp)
{
outVec[comp] = 0;
}
// Shouldn't really do anything.
outPortal.Set(index + this->IndexOffset, outVec);
}
};
template <typename OutArrayHandle>
void ClearPartition(OutArrayHandle& outArray, vtkm::Id startIndex, vtkm::Id numValues)
{
vtkm::cont::Invoker invoke;
invoke(ClearPartitionWorklet{ startIndex }, vtkm::cont::ArrayHandleIndex(numValues), outArray);
}
struct CopyPartitionWorklet : vtkm::worklet::WorkletMapField
{
using ControlSignature = void(FieldIn sourceArray, WholeArrayInOut mergedArray);
using ExecutionSignature = void(WorkIndex, _1, _2);
vtkm::Id IndexOffset;
CopyPartitionWorklet(vtkm::Id indexOffset)
: IndexOffset(indexOffset)
{
}
template <typename InVecType, typename OutPortalType>
VTKM_EXEC void operator()(vtkm::Id index, const InVecType& inVec, OutPortalType& outPortal) const
{
// It's weird to get a value from a portal only to override it, but the expect type
// is weird (a variable-sized Vec), so this is the only practical way to set it.
auto outVec = outPortal.Get(index + this->IndexOffset);
VTKM_ASSERT(inVec.GetNumberOfComponents() == outVec.GetNumberOfComponents());
for (vtkm::IdComponent comp = 0; comp < outVec.GetNumberOfComponents(); ++comp)
{
outVec[comp] = static_cast<typename decltype(outVec)::ComponentType>(inVec[comp]);
}
// Shouldn't really do anything.
outPortal.Set(index + this->IndexOffset, outVec);
}
};
template <typename OutArrayHandle>
void CopyPartition(const vtkm::cont::Field& inField, OutArrayHandle& outArray, vtkm::Id startIndex)
{
vtkm::cont::Invoker invoke;
using ComponentType = typename OutArrayHandle::ValueType::ComponentType;
if (inField.GetData().IsBaseComponentType<ComponentType>())
{
invoke(CopyPartitionWorklet{ startIndex },
inField.GetData().ExtractArrayFromComponents<ComponentType>(),
outArray);
}
else
{
VTKM_LOG_S(vtkm::cont::LogLevel::Info,
"Discovered mismatched types for field " << inField.GetName()
<< ". Requires extra copy.");
invoke(CopyPartitionWorklet{ startIndex },
inField.GetDataAsDefaultFloat().ExtractArrayFromComponents<vtkm::FloatDefault>(),
outArray);
}
}
template <typename HasFieldFunctor, typename GetFieldFunctor>
vtkm::cont::UnknownArrayHandle MergeArray(vtkm::Id numPartitions,
HasFieldFunctor&& hasField,
GetFieldFunctor&& getField,
vtkm::Id totalSize)
{
vtkm::cont::UnknownArrayHandle mergedArray = getField(0).GetData().NewInstanceBasic();
mergedArray.Allocate(totalSize);
vtkm::Id startIndex = 0;
for (vtkm::Id partitionId = 0; partitionId < numPartitions; ++partitionId)
{
vtkm::Id partitionSize;
if (hasField(partitionId, partitionSize))
{
vtkm::cont::Field sourceField = getField(partitionId);
mergedArray.CastAndCallWithExtractedArray(
[=](auto array) { CopyPartition(sourceField, array, startIndex); });
}
else
{
mergedArray.CastAndCallWithExtractedArray(
[=](auto array) { ClearPartition(array, startIndex, partitionSize); });
}
startIndex += partitionSize;
}
VTKM_ASSERT(startIndex == totalSize);
return mergedArray;
}
vtkm::Id GetFirstEmptyPartition(const vtkm::cont::PartitionedDataSet& partitionedDataSet)
{
vtkm::Id numOfDataSet = partitionedDataSet.GetNumberOfPartitions();
@ -588,8 +456,13 @@ void MergeFieldsAndAddIntoDataSet(vtkm::cont::DataSet& outputDataSet,
{
copySize = partitionedDataSet.GetPartition(partitionIndex).GetNumberOfCells();
}
auto viewOut = vtkm::cont::make_ArrayHandleView(concreteOut, offset, copySize);
invoke(SetToInvalid<ComponentType>{ castInvalid }, viewOut);
for (vtkm::IdComponent component = 0; component < concreteOut.GetNumberOfComponents();
++component)
{
//Extracting each component from RecombineVec and copy invalid value into it
//Avoid using invoke to call worklet on ArrayHandleRecombineVec (it may cause long compiling issue on CUDA 12.x).
concreteOut.GetComponentArray(component).Fill(castInvalid, offset, offset + copySize);
}
offset += copySize;
}
}

@ -590,6 +590,55 @@ void TestEmptyPartitions()
"wrong cellVar values");
}
void TestMissingVectorFields()
{
std::cout << "TestMissingVectorFields" << std::endl;
vtkm::cont::DataSetBuilderUniform dsb;
vtkm::Id2 dimensions(3, 2);
vtkm::cont::DataSet dataSet1 = dsb.Create(dimensions, vtkm::Vec2f(0.0, 0.0), vtkm::Vec2f(1, 1));
vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Float64, 4>> pointVarVec4;
pointVarVec4.Allocate(6);
vtkm::cont::Invoker invoker;
invoker(SetPointValuesV4Worklet{}, dataSet1.GetCoordinateSystem().GetData(), pointVarVec4);
dataSet1.AddPointField("pointVarV4", pointVarVec4);
vtkm::cont::DataSet dataSet2 = dsb.Create(dimensions, vtkm::Vec2f(0.0, 0.0), vtkm::Vec2f(1, 1));
vtkm::cont::ArrayHandle<vtkm::Vec3f_64> cellVarVec3 =
vtkm::cont::make_ArrayHandle<vtkm::Vec3f_64>({ { 1.0, 2.0, 3.0 }, { 4.0, 5.0, 6.0 } });
dataSet2.AddCellField("cellVarV3", cellVarVec3);
vtkm::cont::PartitionedDataSet inputDataSets;
inputDataSets.AppendPartition(dataSet1);
inputDataSets.AppendPartition(dataSet2);
vtkm::filter::multi_block::MergeDataSets mergeDataSets;
auto result = mergeDataSets.Execute(inputDataSets);
//checking results
vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Float64, 4>> validatePointVar =
vtkm::cont::make_ArrayHandle<vtkm::Vec<vtkm::Float64, 4>>(
{ { 0, 0, 0, 0 },
{ 0.1, 0, 0, 0.1 },
{ 0.2, 0, 0, 0.2 },
{ 0, 0.1, 0, 0 },
{ 0.1, 0.1, 0, 0.1 },
{ 0.2, 0.1, 0, 0.2 },
{ vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() },
{ vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() },
{ vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() },
{ vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() },
{ vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() },
{ vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() } });
vtkm::cont::ArrayHandle<vtkm::Vec3f_64> validateCellVar =
vtkm::cont::make_ArrayHandle<vtkm::Vec3f_64>({ { vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() },
{ vtkm::Nan64(), vtkm::Nan64(), vtkm::Nan64() },
{ 1.0, 2.0, 3.0 },
{ 4.0, 5.0, 6.0 } });
VTKM_TEST_ASSERT(test_equal_ArrayHandles(result.GetPartition(0).GetField("pointVarV4").GetData(),
validatePointVar),
"wrong point values for TestMissingVectorFields");
VTKM_TEST_ASSERT(test_equal_ArrayHandles(result.GetPartition(0).GetField("cellVarV3").GetData(),
validateCellVar),
"wrong cell values for TestMissingVectorFields");
}
void TestMergeDataSetsFilter()
{
//same cell type (triangle), same field name, same data type, cellset is single type
@ -606,12 +655,14 @@ void TestMergeDataSetsFilter()
TestDiffCellsSameFieldsSameDataType();
//test multiple partitions
TestMoreThanTwoPartitions();
//some partitions have missing fields
//some partitions have missing scalar fields
TestMissingFieldsAndSameFieldName();
//test empty partitions
TestEmptyPartitions();
//test customized types
TestCustomizedVecField();
//some partitions have missing vector fields
TestMissingVectorFields();
}
} // anonymous namespace
int UnitTestMergeDataSetsFilter(int argc, char* argv[])