Consolidate flying edges device cases

The flying edges code has two main implementations: one where it
traverses the mesh in the X direction (for CPU) and one where it
traverses the mesh in the Y direction (for GPU). There were several
places in the code where the device was checked for the case to use.
This resulted in several places where the "GPU" devices had to be
listed.

Improve the code by consolidate the check to a single piece of code in
`FlyingEdgesHelpers.h` where the axis to sum structure is selected. Make
the overloads based on this structure rather than directly on the
devices.
This commit is contained in:
Kenneth Moreland 2024-09-17 14:03:28 -04:00
parent 4e5fb16d00
commit cbb5871eae
2 changed files with 36 additions and 61 deletions

@ -45,8 +45,8 @@ namespace flying_edges
*
*/
template <typename Device, typename WholeEdgeField>
inline VTKM_EXEC void write_edge(Device,
template <typename WholeEdgeField>
inline VTKM_EXEC void write_edge(SumXAxis,
vtkm::Id write_index,
WholeEdgeField& edges,
vtkm::UInt8 edgeCase)
@ -55,7 +55,7 @@ inline VTKM_EXEC void write_edge(Device,
}
template <typename WholeEdgeField>
inline VTKM_EXEC void write_edge(vtkm::cont::DeviceAdapterTagCuda,
inline VTKM_EXEC void write_edge(SumYAxis,
vtkm::Id write_index,
WholeEdgeField& edges,
vtkm::UInt8 edgeCase)
@ -98,7 +98,7 @@ struct ComputePass1 : public vtkm::worklet::WorkletVisitPointsWithCells
vtkm::Id& axis_max,
WholeEdgeField& edges,
const WholeDataField& field,
Device device) const
Device) const
{
using AxisToSum = typename select_AxisToSum<Device>::type;
@ -129,7 +129,7 @@ struct ComputePass1 : public vtkm::worklet::WorkletVisitPointsWithCells
edgeCase |= FlyingEdges3D::RightAbove;
}
write_edge(device, startPos + (offset * i), edges, edgeCase);
write_edge(AxisToSum{}, startPos + (offset * i), edges, edgeCase);
if (edgeCase == FlyingEdges3D::LeftAbove || edgeCase == FlyingEdges3D::RightAbove)
{
@ -141,68 +141,42 @@ struct ComputePass1 : public vtkm::worklet::WorkletVisitPointsWithCells
}
}
}
write_edge(device, startPos + (offset * end), edges, FlyingEdges3D::Below);
write_edge(AxisToSum{}, startPos + (offset * end), edges, FlyingEdges3D::Below);
}
};
struct launchComputePass1
{
template <typename DeviceAdapterTag,
typename IVType,
typename T,
typename StorageTagField,
typename... Args>
VTKM_CONT bool LaunchXAxis(DeviceAdapterTag device,
const ComputePass1<IVType>& worklet,
const vtkm::cont::ArrayHandle<T, StorageTagField>& inputField,
vtkm::cont::ArrayHandle<vtkm::UInt8>& edgeCases,
vtkm::cont::CellSetStructured<2>& metaDataMesh2D,
Args&&... args) const
void FillEdgeCases(vtkm::cont::ArrayHandle<vtkm::UInt8>&, SumXAxis) const
{
vtkm::cont::Invoker invoke(device);
metaDataMesh2D = make_metaDataMesh2D(SumXAxis{}, worklet.PointDims);
invoke(worklet, metaDataMesh2D, std::forward<Args>(args)..., edgeCases, inputField);
return true;
// Do nothing
}
template <typename DeviceAdapterTag,
typename IVType,
typename T,
typename StorageTagField,
typename... Args>
VTKM_CONT bool LaunchYAxis(DeviceAdapterTag device,
const ComputePass1<IVType>& worklet,
const vtkm::cont::ArrayHandle<T, StorageTagField>& inputField,
vtkm::cont::ArrayHandle<vtkm::UInt8>& edgeCases,
vtkm::cont::CellSetStructured<2>& metaDataMesh2D,
Args&&... args) const
void FillEdgeCases(vtkm::cont::ArrayHandle<vtkm::UInt8>& edgeCases, SumYAxis) const
{
vtkm::cont::Invoker invoke(device);
metaDataMesh2D = make_metaDataMesh2D(SumYAxis{}, worklet.PointDims);
edgeCases.Fill(static_cast<vtkm::UInt8>(FlyingEdges3D::Below));
}
template <typename DeviceAdapterTag,
typename IVType,
typename T,
typename StorageTagField,
typename... Args>
VTKM_CONT bool operator()(DeviceAdapterTag device,
const ComputePass1<IVType>& worklet,
const vtkm::cont::ArrayHandle<T, StorageTagField>& inputField,
vtkm::cont::ArrayHandle<vtkm::UInt8>& edgeCases,
vtkm::cont::CellSetStructured<2>& metaDataMesh2D,
Args&&... args) const
{
using AxisToSum = typename select_AxisToSum<DeviceAdapterTag>::type;
vtkm::cont::Invoker invoke(device);
metaDataMesh2D = make_metaDataMesh2D(AxisToSum{}, worklet.PointDims);
this->FillEdgeCases(edgeCases, AxisToSum{});
invoke(worklet, metaDataMesh2D, std::forward<Args>(args)..., edgeCases, inputField);
return true;
}
template <typename DeviceAdapterTag, typename... Args>
VTKM_CONT bool operator()(DeviceAdapterTag device, Args&&... args) const
{
return this->LaunchXAxis(device, std::forward<Args>(args)...);
}
template <typename... Args>
VTKM_CONT bool operator()(vtkm::cont::DeviceAdapterTagCuda device, Args&&... args) const
{
return this->LaunchYAxis(device, std::forward<Args>(args)...);
}
template <typename... Args>
VTKM_CONT bool operator()(vtkm::cont::DeviceAdapterTagKokkos device, Args&&... args) const
{
return this->LaunchYAxis(device, std::forward<Args>(args)...);
}
};
}
}

@ -168,21 +168,22 @@ struct launchComputePass4
}
template <typename DeviceAdapterTag, typename... Args>
VTKM_CONT bool operator()(DeviceAdapterTag device, Args&&... args) const
VTKM_CONT bool Launch(SumXAxis, DeviceAdapterTag device, Args&&... args) const
{
return this->LaunchXAxis(device, std::forward<Args>(args)...);
}
template <typename... Args>
VTKM_CONT bool operator()(vtkm::cont::DeviceAdapterTagCuda device, Args&&... args) const
template <typename DeviceAdapterTag, typename... Args>
VTKM_CONT bool Launch(SumYAxis, DeviceAdapterTag device, Args&&... args) const
{
return this->LaunchYAxis(device, std::forward<Args>(args)...);
}
template <typename... Args>
VTKM_CONT bool operator()(vtkm::cont::DeviceAdapterTagKokkos device, Args&&... args) const
template <typename DeviceAdapterTag, typename... Args>
VTKM_CONT bool operator()(DeviceAdapterTag device, Args&&... args) const
{
return this->LaunchYAxis(device, std::forward<Args>(args)...);
return this->Launch(
(typename select_AxisToSum<DeviceAdapterTag>::type){}, device, std::forward<Args>(args)...);
}
};
}