Merge branch 'jit-lagrange' into 'master'

WIP: Jit lagrange

See merge request vtk/vtk-m!2287
This commit is contained in:
Hank 2024-07-09 17:09:20 -04:00
commit da4775fa1b
10 changed files with 390 additions and 4 deletions

@ -13,6 +13,7 @@
#include <vtkm/StaticAssert.h>
#include <vtkm/Types.h>
#include <lcl/Lagrange_Hexahedron.h>
#include <lcl/Polygon.h>
#include <lcl/Shapes.h>
@ -66,6 +67,7 @@ enum CellShapeIdEnum
CELL_SHAPE_WEDGE = lcl::ShapeId::WEDGE,
/// A pyramid with a quadrilateral base and four triangular faces.0
CELL_SHAPE_PYRAMID = lcl::ShapeId::PYRAMID,
CELL_SHAPE_LAGRANGE_HEXAHEDRON = lcl::ShapeId::LAGRANGE_HEXAHEDRON,
NUMBER_OF_CELL_SHAPES
};
@ -159,6 +161,7 @@ VTKM_DEFINE_CELL_TAG(Tetra, CELL_SHAPE_TETRA);
VTKM_DEFINE_CELL_TAG(Hexahedron, CELL_SHAPE_HEXAHEDRON);
VTKM_DEFINE_CELL_TAG(Wedge, CELL_SHAPE_WEDGE);
VTKM_DEFINE_CELL_TAG(Pyramid, CELL_SHAPE_PYRAMID);
VTKM_DEFINE_CELL_TAG(Lagrange_Hexahedron, CELL_SHAPE_LAGRANGE_HEXAHEDRON);
#undef VTKM_DEFINE_CELL_TAG
@ -200,6 +203,13 @@ inline lcl::Polygon make_LclCellShapeTag(const vtkm::CellShapeTagPolygon&,
return lcl::Polygon(numPoints);
}
VTKM_EXEC_CONT
inline lcl::Lagrange_Hexahedron make_LclCellShapeTag(const vtkm::CellShapeTagLagrange_Hexahedron&,
vtkm::IdComponent numPoints = 0)
{
return lcl::Lagrange_Hexahedron(numPoints);
}
VTKM_EXEC_CONT
inline lcl::Cell make_LclCellShapeTag(const vtkm::CellShapeTagGeneric& tag,
vtkm::IdComponent numPoints = 0)
@ -258,7 +268,8 @@ inline lcl::Cell make_LclCellShapeTag(const vtkm::CellShapeTagGeneric& tag,
vtkmGenericCellShapeMacroCase(CELL_SHAPE_TETRA, call); \
vtkmGenericCellShapeMacroCase(CELL_SHAPE_HEXAHEDRON, call); \
vtkmGenericCellShapeMacroCase(CELL_SHAPE_WEDGE, call); \
vtkmGenericCellShapeMacroCase(CELL_SHAPE_PYRAMID, call)
vtkmGenericCellShapeMacroCase(CELL_SHAPE_PYRAMID, call); \
vtkmGenericCellShapeMacroCase(CELL_SHAPE_LAGRANGE_HEXAHEDRON, call);
} // namespace vtkm

@ -115,6 +115,7 @@ VTKM_DEFINE_CELL_TRAITS(Tetra, 3, 4);
VTKM_DEFINE_CELL_TRAITS(Hexahedron, 3, 8);
VTKM_DEFINE_CELL_TRAITS(Wedge, 3, 6);
VTKM_DEFINE_CELL_TRAITS(Pyramid, 3, 5);
VTKM_DEFINE_CELL_TRAITS_VARIABLE(Lagrange_Hexahedron, 3);
#undef VTKM_DEFINE_CELL_TRAITS

@ -40,11 +40,20 @@ public:
: Indices(indices)
, Portal(portal)
{
Limit = 0;
}
void LimitNumberOfComponents(vtkm::IdComponent l) { Limit = l; }
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC_CONT
vtkm::IdComponent GetNumberOfComponents() const { return this->Indices->GetNumberOfComponents(); }
vtkm::IdComponent GetNumberOfComponents() const
{
vtkm::IdComponent numComps = this->Indices->GetNumberOfComponents();
if (Limit > 0)
return (Limit < numComps ? Limit : numComps);
return numComps;
}
VTKM_SUPPRESS_EXEC_WARNINGS
template <vtkm::IdComponent DestSize>
@ -67,6 +76,7 @@ public:
private:
const IndexVecType* const Indices;
PortalType Portal;
vtkm::IdComponent Limit;
};
template <typename IndexVecType, typename PortalType>

@ -160,6 +160,19 @@ static inline VTKM_EXEC vtkm::ErrorCode CellEdgeNumberOfEdges(vtkm::IdComponent
/// @param[in] shape A tag of type `CellShapeTag*` to identify the shape of the cell.
/// This method is overloaded for different shape types.
/// @param[out] numEdges A reference to return the number of edges.
static inline VTKM_EXEC vtkm::ErrorCode CellEdgeNumberOfEdges(vtkm::IdComponent numPoints,
vtkm::CellShapeTagLagrange_Hexahedron,
vtkm::IdComponent& numEdges)
{
if (numPoints <= 0)
{
numEdges = -1;
return vtkm::ErrorCode::InvalidNumberOfPoints;
}
numEdges = numPoints; // HC: this is what SW did, but possibly it was a shortcut
return vtkm::ErrorCode::Success;
}
static inline VTKM_EXEC vtkm::ErrorCode CellEdgeNumberOfEdges(vtkm::IdComponent numPoints,
vtkm::CellShapeTagGeneric shape,
vtkm::IdComponent& numEdges)

@ -19,6 +19,7 @@
#include <vtkm/internal/ArrayPortalUniformPointCoordinates.h>
#include <vtkm/VecAxisAlignedPointCoordinates.h>
#include <vtkm/exec/ConnectivityExplicit.h>
#include <vtkm/exec/ConnectivityExtrude.h>
#include <vtkm/exec/ConnectivityStructured.h>
@ -154,6 +155,67 @@ struct FetchArrayTopologyMapInImplementation<
}
};
template <typename PortalType>
VTKM_EXEC inline void PerformHigherOrderToLowerOrderCellSubstitution(
PortalType pt,
vtkm::CellShapeTagLagrange_Hexahedron)
{
pt.LimitNumberOfComponents(8);
}
template <typename PortalType>
VTKM_EXEC inline void PerformHigherOrderToLowerOrderCellSubstitution(
PortalType pt,
vtkm::CellShapeTagGeneric shape)
{
if (shape.Id == vtkm::CELL_SHAPE_LAGRANGE_HEXAHEDRON)
pt.LimitNumberOfComponents(8);
}
template <typename PortalType, typename CellShapeTag>
VTKM_EXEC inline void PerformHigherOrderToLowerOrderCellSubstitution(PortalType, CellShapeTag)
{
return;
}
template <typename PortalType, typename CellShapeTag>
VTKM_EXEC inline void PerformInlineShapeSubstitution(PortalType pt, CellShapeTag shape)
{
PerformHigherOrderToLowerOrderCellSubstitution(pt, shape);
}
template <typename FieldExecObjectType, typename X, typename Y, typename Z, typename W>
struct FetchArrayTopologyMapInImplementation<
vtkm::exec::ConnectivityExplicit<X, Y, Z>,
FieldExecObjectType,
ThreadIndicesTopologyMap<vtkm::exec::ConnectivityExplicit<X, Y, Z>, W>>
{
using ThreadIndicesTopologyMapType =
const ThreadIndicesTopologyMap<vtkm::exec::ConnectivityExplicit<X, Y, Z>, W>;
// stored in a Vec-like object.
using IndexVecType = typename ThreadIndicesTopologyMapType::IndicesIncidentType;
// The FieldExecObjectType is expected to behave like an ArrayPortal.
using PortalType = FieldExecObjectType;
using ValueType = vtkm::VecFromPortalPermute<IndexVecType, PortalType>;
VTKM_SUPPRESS_EXEC_WARNINGS
VTKM_EXEC
static ValueType Load(ThreadIndicesTopologyMapType& indices, const FieldExecObjectType& field)
{
std::cerr << "LOAD5" << std::endl;
ValueType rv(indices.GetIndicesIncidentPointer(), field);
PerformInlineShapeSubstitution(rv, indices.GetCellShape());
//if (indices.GetCellShape() == vtkm::CellShapeTagLagrange_Hexahedron())
//rv.LimitNumberOfComponents(8);
return rv;
}
};
template <typename PermutationPortal, vtkm::IdComponent NumDimensions, typename ThreadIndicesType>
struct FetchArrayTopologyMapInImplementation<
vtkm::exec::ConnectivityPermutedVisitCellsWithPoints<

@ -221,6 +221,16 @@ struct TestCellFacesFunctor
this->TryShapeWithNumPoints(numPoints, vtkm::CellShapeTagPolygon());
}
}
void operator()(vtkm::CellShapeTagLagrange_Hexahedron) const
{
// n^3
for (vtkm::IdComponent numPoints = 8; numPoints < 28;
numPoints = static_cast<vtkm::IdComponent>(vtkm::Pow(numPoints, 3.)))
{
this->TryShapeWithNumPoints(numPoints, vtkm::CellShapeTagLagrange_Hexahedron());
}
}
};
void TestAllShapes()

@ -0,0 +1,274 @@
//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.md for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#ifndef lcl_Lagrange_Hexahedron_h
#define lcl_Lagrange_Hexahedron_h
#include <vtkm/Math.h>
#include <lcl/ErrorCode.h>
#include <lcl/Shapes.h>
#include <lcl/internal/Common.h>
namespace lcl
{
class Lagrange_Hexahedron : public Cell
{
public:
constexpr LCL_EXEC Lagrange_Hexahedron() : Cell(ShapeId::LAGRANGE_HEXAHEDRON, 8) {}
constexpr LCL_EXEC explicit Lagrange_Hexahedron(lcl::IdComponent numPoints)
: Cell(ShapeId::LAGRANGE_HEXAHEDRON, numPoints)
{
}
constexpr LCL_EXEC explicit Lagrange_Hexahedron(const Cell& cell) : Cell(cell) {}
};
LCL_EXEC inline lcl::ErrorCode validate(Lagrange_Hexahedron tag) noexcept
{
if (tag.shape() != ShapeId::LAGRANGE_HEXAHEDRON && tag.shape() != ShapeId::VOXEL)
{
return ErrorCode::WRONG_SHAPE_ID_FOR_TAG_TYPE;
}
if (static_cast<lcl::IdComponent>(vtkm::Cbrt(tag.numberOfPoints()))%3 != 0) // 0th order Lagrange is a Hex
{
return ErrorCode::INVALID_NUMBER_OF_POINTS;
}
return ErrorCode::SUCCESS;
}
template<typename CoordType>
LCL_EXEC inline lcl::ErrorCode parametricCenter(Lagrange_Hexahedron, CoordType&& pcoords) noexcept
{
LCL_STATIC_ASSERT_PCOORDS_IS_FLOAT_TYPE(CoordType);
// All Lagrange Unit Hexes have the same parametric center
component(pcoords, 0) = 0.5f;
component(pcoords, 1) = 0.5f;
component(pcoords, 2) = 0.5f;
return ErrorCode::SUCCESS;
}
/*
* Ordering inspired by vtk quadratic hex
* First 8 points are the normal unit hex
* Next 8 points are the midpoints for the unit hex (starts <0.5,0,0>)
* Next 4 points is unit square at z=0.5 (starts <0,0,0.5>)
* Next 4 points are the midpoints of unit square at z=0.5 (starts <0.5,0,0.5>)
* Final 3 points are the vertical central line (<0.5,0.5,0> -> <0.5,0.5,1>)
*/
template<typename CoordType>
LCL_EXEC inline lcl::ErrorCode parametricPoint(
Lagrange_Hexahedron, IdComponent pointId, CoordType&& pcoords) noexcept
{
LCL_STATIC_ASSERT_PCOORDS_IS_FLOAT_TYPE(CoordType);
switch (pointId)
{
case 0:
component(pcoords, 0) = 0.0f;
component(pcoords, 1) = 0.0f;
component(pcoords, 2) = 0.0f;
break;
case 1:
component(pcoords, 0) = 1.0f;
component(pcoords, 1) = 0.0f;
component(pcoords, 2) = 0.0f;
break;
case 2:
component(pcoords, 0) = 1.0f;
component(pcoords, 1) = 1.0f;
component(pcoords, 2) = 0.0f;
break;
case 3:
component(pcoords, 0) = 0.0f;
component(pcoords, 1) = 1.0f;
component(pcoords, 2) = 0.0f;
break;
case 4:
component(pcoords, 0) = 0.0f;
component(pcoords, 1) = 0.0f;
component(pcoords, 2) = 1.0f;
break;
case 5:
component(pcoords, 0) = 1.0f;
component(pcoords, 1) = 0.0f;
component(pcoords, 2) = 1.0f;
break;
case 6:
component(pcoords, 0) = 1.0f;
component(pcoords, 1) = 1.0f;
component(pcoords, 2) = 1.0f;
break;
case 7:
component(pcoords, 0) = 0.0f;
component(pcoords, 1) = 1.0f;
component(pcoords, 2) = 1.0f;
break;
default:
return ErrorCode::INVALID_POINT_ID;
}
return ErrorCode::SUCCESS;
}
template<typename CoordType>
LCL_EXEC inline ComponentType<CoordType> parametricDistance(Lagrange_Hexahedron, const CoordType& pcoords) noexcept
{
LCL_STATIC_ASSERT_PCOORDS_IS_FLOAT_TYPE(CoordType);
return internal::findParametricDistance(pcoords, 3);
}
template<typename CoordType>
LCL_EXEC inline bool cellInside(Lagrange_Hexahedron, const CoordType& pcoords) noexcept
{
LCL_STATIC_ASSERT_PCOORDS_IS_FLOAT_TYPE(CoordType);
using T = ComponentType<CoordType>;
constexpr T eps = 1e-6f;
return component(pcoords, 0) >= -eps && component(pcoords, 0) <= (T{1} + eps) &&
component(pcoords, 1) >= -eps && component(pcoords, 1) <= (T{1} + eps) &&
component(pcoords, 2) >= -eps && component(pcoords, 2) <= (T{1} + eps);
}
// TODO: How does this work with Lagrange?
// Implementation should ignore midpoints and approximate as linear cell
template <typename Values, typename CoordType, typename Result>
LCL_EXEC inline lcl::ErrorCode interpolate(
Lagrange_Hexahedron,
const Values& values,
const CoordType& pcoords,
Result&& result) noexcept
{
LCL_STATIC_ASSERT_PCOORDS_IS_FLOAT_TYPE(CoordType);
using T = internal::ClosestFloatType<typename Values::ValueType>;
for (IdComponent c = 0; c < values.getNumberOfComponents(); ++c)
{
auto vbf = internal::lerp(static_cast<T>(values.getValue(0, c)),
static_cast<T>(values.getValue(1, c)),
static_cast<T>(component(pcoords, 0)));
auto vbb = internal::lerp(static_cast<T>(values.getValue(3, c)),
static_cast<T>(values.getValue(2, c)),
static_cast<T>(component(pcoords, 0)));
auto vtf = internal::lerp(static_cast<T>(values.getValue(4, c)),
static_cast<T>(values.getValue(5, c)),
static_cast<T>(component(pcoords, 0)));
auto vtb = internal::lerp(static_cast<T>(values.getValue(7, c)),
static_cast<T>(values.getValue(6, c)),
static_cast<T>(component(pcoords, 0)));
auto vb = internal::lerp(vbf, vbb, static_cast<T>(component(pcoords, 1)));
auto vt = internal::lerp(vtf, vtb, static_cast<T>(component(pcoords, 1)));
auto v = internal::lerp(vb, vt, static_cast<T>(component(pcoords, 2)));
component(result, c) = static_cast<ComponentType<Result>>(v);
}
return ErrorCode::SUCCESS;
}
namespace internal
{
// TODO: Fix for 27
// Current is Linear?
template <typename Values, typename CoordType, typename Result>
LCL_EXEC inline void parametricDerivative(Lagrange_Hexahedron,
const Values& values,
IdComponent comp,
const CoordType& pcoords,
Result&& result) noexcept
{
using T = internal::ClosestFloatType<typename Values::ValueType>;
T p0 = static_cast<T>(component(pcoords, 0));
T p1 = static_cast<T>(component(pcoords, 1));
T p2 = static_cast<T>(component(pcoords, 2));
T rm = T{1} - p0;
T sm = T{1} - p1;
T tm = T{1} - p2;
T dr = (static_cast<T>(values.getValue(0, comp)) * -sm * tm) +
(static_cast<T>(values.getValue(1, comp)) * sm * tm) +
(static_cast<T>(values.getValue(2, comp)) * p1 * tm) +
(static_cast<T>(values.getValue(3, comp)) * -p1 * tm) +
(static_cast<T>(values.getValue(4, comp)) * -sm * p2) +
(static_cast<T>(values.getValue(5, comp)) * sm * p2) +
(static_cast<T>(values.getValue(6, comp)) * p1 * p2) +
(static_cast<T>(values.getValue(7, comp)) * -p1 * p2);
T ds = (static_cast<T>(values.getValue(0, comp)) * -rm * tm) +
(static_cast<T>(values.getValue(1, comp)) * -p0 * tm) +
(static_cast<T>(values.getValue(2, comp)) * p0 * tm) +
(static_cast<T>(values.getValue(3, comp)) * rm * tm) +
(static_cast<T>(values.getValue(4, comp)) * -rm * p2) +
(static_cast<T>(values.getValue(5, comp)) * -p0 * p2) +
(static_cast<T>(values.getValue(6, comp)) * p0 * p2) +
(static_cast<T>(values.getValue(7, comp)) * rm * p2);
T dt = (static_cast<T>(values.getValue(0, comp)) * -rm * sm) +
(static_cast<T>(values.getValue(1, comp)) * -p0 * sm) +
(static_cast<T>(values.getValue(2, comp)) * -p0 * p1) +
(static_cast<T>(values.getValue(3, comp)) * -rm * p1) +
(static_cast<T>(values.getValue(4, comp)) * rm * sm) +
(static_cast<T>(values.getValue(5, comp)) * p0 * sm) +
(static_cast<T>(values.getValue(6, comp)) * p0 * p1) +
(static_cast<T>(values.getValue(7, comp)) * rm * p1);
component(result, 0) = static_cast<ComponentType<Result>>(dr);
component(result, 1) = static_cast<ComponentType<Result>>(ds);
component(result, 2) = static_cast<ComponentType<Result>>(dt);
}
} // internal
template <typename Points, typename Values, typename CoordType, typename Result>
LCL_EXEC inline lcl::ErrorCode derivative(
Lagrange_Hexahedron,
const Points& points,
const Values& values,
const CoordType& pcoords,
Result&& dx,
Result&& dy,
Result&& dz) noexcept
{
return internal::derivative3D(Lagrange_Hexahedron{},
points,
values,
pcoords,
std::forward<Result>(dx),
std::forward<Result>(dy),
std::forward<Result>(dz));
}
template <typename Points, typename PCoordType, typename WCoordType>
LCL_EXEC inline lcl::ErrorCode parametricToWorld(
Lagrange_Hexahedron,
const Points& points,
const PCoordType& pcoords,
WCoordType&& wcoords) noexcept
{
return interpolate(Lagrange_Hexahedron{}, points, pcoords, std::forward<WCoordType>(wcoords));
}
template <typename Points, typename WCoordType, typename PCoordType>
LCL_EXEC inline lcl::ErrorCode worldToParametric(
Lagrange_Hexahedron,
const Points& points,
const WCoordType& wcoords,
PCoordType&& pcoords) noexcept
{
return internal::worldToParametric3D(
Lagrange_Hexahedron{}, points, wcoords, std::forward<PCoordType>(pcoords));
}
} // lcl
#endif // lcl_Lagrange_Hexahedron_h

@ -35,7 +35,7 @@ enum ShapeId : IdShape
HEXAHEDRON = 12,
WEDGE = 13,
PYRAMID = 14,
LAGRANGE_HEXAHEDRON = 72,
NUMBER_OF_CELL_SHAPES
};
@ -87,6 +87,7 @@ inline LCL_EXEC int dimension(IdShape shapeId)
case HEXAHEDRON:
case WEDGE:
case PYRAMID:
case LAGRANGE_HEXAHEDRON:
return 3;
case EMPTY:
default:
@ -115,6 +116,7 @@ class Voxel;
class Hexahedron;
class Wedge;
class Pyramid;
class Lagrange_Hexahedron;
} //namespace lcl
@ -137,6 +139,7 @@ class Pyramid;
lclGenericCellShapeMacroCase(lcl::ShapeId::VOXEL, lcl::Voxel, call); \
lclGenericCellShapeMacroCase(lcl::ShapeId::HEXAHEDRON, lcl::Hexahedron, call); \
lclGenericCellShapeMacroCase(lcl::ShapeId::WEDGE, lcl::Wedge, call); \
lclGenericCellShapeMacroCase(lcl::ShapeId::PYRAMID, lcl::Pyramid, call)
lclGenericCellShapeMacroCase(lcl::ShapeId::PYRAMID, lcl::Pyramid, call); \
lclGenericCellShapeMacroCase(lcl::ShapeId::LAGRANGE_HEXAHEDRON, lcl::Lagrange_Hexahedron, call)
#endif //lcl_Shapes_h

@ -66,6 +66,7 @@ FORWARD_DECLAR_PARAMETRIC_DERIVATIVE(Tetra);
FORWARD_DECLAR_PARAMETRIC_DERIVATIVE(Hexahedron);
FORWARD_DECLAR_PARAMETRIC_DERIVATIVE(Wedge);
FORWARD_DECLAR_PARAMETRIC_DERIVATIVE(Pyramid);
FORWARD_DECLAR_PARAMETRIC_DERIVATIVE(Lagrange_Hexahedron);
#undef FORWARD_DECLAR_PARAMETRIC_DERIVATIVE

@ -21,6 +21,7 @@
#include <lcl/Vertex.h>
#include <lcl/Voxel.h>
#include <lcl/Wedge.h>
#include <lcl/Lagrange_Hexahedron.h>
#include <utility>