Merge branch 'pointlocator' into pointlocator2
This commit is contained in:
commit
2e519f6508
@ -28,7 +28,7 @@
|
||||
|
||||
namespace vtkm
|
||||
{
|
||||
namespace worklet
|
||||
namespace cont
|
||||
{
|
||||
template <typename T>
|
||||
class PointLocatorUniformGrid
|
||||
@ -71,46 +71,56 @@ public:
|
||||
vtkm::Vec<T, 3> Dxdydz;
|
||||
};
|
||||
|
||||
class UniformGridSearch : public vtkm::worklet::WorkletMapField
|
||||
template <typename DeviceAdapter>
|
||||
class Locator : public vtkm::exec::ExecutionObjectBase
|
||||
{
|
||||
public:
|
||||
using ControlSignature = void(FieldIn<> query,
|
||||
WholeArrayIn<> coordIn,
|
||||
WholeArrayIn<IdType> pointId,
|
||||
WholeArrayIn<IdType> cellLower,
|
||||
WholeArrayIn<IdType> cellUpper,
|
||||
FieldOut<IdType> neighborId,
|
||||
FieldOut<> distance);
|
||||
|
||||
using ExecutionSignature = void(_1, _2, _3, _4, _5, _6, _7);
|
||||
using CoordPortalType = typename vtkm::cont::ArrayHandle<
|
||||
vtkm::Vec<T, 3>>::template ExecutionTypes<DeviceAdapter>::PortalConst;
|
||||
using IdPortalType = typename vtkm::cont::ArrayHandle<vtkm::Id>::template ExecutionTypes<
|
||||
DeviceAdapter>::PortalConst;
|
||||
|
||||
VTKM_CONT
|
||||
UniformGridSearch(const vtkm::Vec<T, 3>& _min,
|
||||
const vtkm::Vec<T, 3>& _max,
|
||||
const vtkm::Vec<vtkm::Id, 3>& _dims)
|
||||
Locator() = default;
|
||||
|
||||
VTKM_CONT
|
||||
Locator(const vtkm::Vec<T, 3>& _min,
|
||||
const vtkm::Vec<T, 3>& _max,
|
||||
const vtkm::Vec<vtkm::Id, 3>& _dims,
|
||||
const CoordPortalType& coords,
|
||||
const IdPortalType& pointIds,
|
||||
const IdPortalType& cellLower,
|
||||
const IdPortalType& cellUpper)
|
||||
: Min(_min)
|
||||
, Dims(_dims)
|
||||
, Dxdydz((_max - _min) / _dims)
|
||||
, Dxdydz((_max - Min) / Dims)
|
||||
, coords(coords)
|
||||
, pointIds(pointIds)
|
||||
, cellLower(cellLower)
|
||||
, cellUpper(cellUpper)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
template <typename CoordiVecType,
|
||||
typename IdPortalType,
|
||||
typename CoordiPortalType,
|
||||
typename IdType,
|
||||
typename CoordiType>
|
||||
VTKM_EXEC void operator()(const CoordiVecType& queryCoord,
|
||||
const CoordiPortalType& coordi_Handle,
|
||||
const IdPortalType& pointId,
|
||||
const IdPortalType& cellLower,
|
||||
const IdPortalType& cellUpper,
|
||||
IdType& nnId,
|
||||
CoordiType& nnDis) const
|
||||
/// \brief Nearest neighbor search using a Uniform Grid
|
||||
///
|
||||
/// Parallel search of nearesat neighbor for each point in the \c queryPoints in the set of
|
||||
/// \c coords. Returns neareast neighbot in \c nearestNeighborIds and distances to nearest
|
||||
/// neighbor in \c distances.
|
||||
///
|
||||
/// \param coords Point coordinates for training dataset.
|
||||
/// \param queryPoints Point coordinates to query for nearest neighbors.
|
||||
/// \param nearestNeighborIds Neareast neighbor in the training dataset for each points in
|
||||
/// the test set
|
||||
/// \param distances Distance between query points and their nearest neighbors.
|
||||
/// \param device Tag for selecting device adapter.
|
||||
VTKM_EXEC
|
||||
void FindNearestPoint(const vtkm::Vec<T, 3>& queryPoint,
|
||||
vtkm::Id& nearestNeighborId,
|
||||
T& distance) const
|
||||
{
|
||||
auto nlayers = vtkm::Max(vtkm::Max(Dims[0], Dims[1]), Dims[2]);
|
||||
|
||||
vtkm::Vec<vtkm::Id, 3> xyz = (queryCoord - Min) / Dxdydz;
|
||||
vtkm::Vec<vtkm::Id, 3> xyz = (queryPoint - Min) / Dxdydz;
|
||||
|
||||
float min_distance = std::numeric_limits<float>::max();
|
||||
vtkm::Id neareast = -1;
|
||||
@ -138,11 +148,11 @@ public:
|
||||
auto upper = cellUpper.Get(cellid);
|
||||
for (auto index = lower; index < upper; index++)
|
||||
{
|
||||
auto pointid = pointId.Get(index);
|
||||
auto point = coordi_Handle.Get(pointid);
|
||||
auto dx = point[0] - queryCoord[0];
|
||||
auto dy = point[1] - queryCoord[1];
|
||||
auto dz = point[2] - queryCoord[2];
|
||||
auto pointid = pointIds.Get(index);
|
||||
auto point = coords.Get(pointid);
|
||||
auto dx = point[0] - queryPoint[0];
|
||||
auto dy = point[1] - queryPoint[1];
|
||||
auto dz = point[2] - queryPoint[2];
|
||||
|
||||
auto distance2 = dx * dx + dy * dy + dz * dz;
|
||||
if (distance2 < min_distance)
|
||||
@ -158,14 +168,21 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
nnId = neareast;
|
||||
nnDis = vtkm::Sqrt(min_distance);
|
||||
}
|
||||
nearestNeighborId = neareast;
|
||||
distance = vtkm::Sqrt(min_distance);
|
||||
};
|
||||
|
||||
private:
|
||||
vtkm::Vec<T, 3> Min;
|
||||
vtkm::Vec<vtkm::Id, 3> Dims;
|
||||
vtkm::Vec<T, 3> Dxdydz;
|
||||
|
||||
CoordPortalType coords;
|
||||
|
||||
IdPortalType pointIds;
|
||||
IdPortalType cellIds;
|
||||
IdPortalType cellLower;
|
||||
IdPortalType cellUpper;
|
||||
};
|
||||
|
||||
/// \brief Construct a 3D uniform grid for nearest neighbor search.
|
||||
@ -177,6 +194,10 @@ public:
|
||||
DeviceAdapter vtkmNotUsed(device))
|
||||
{
|
||||
using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
|
||||
|
||||
// Save training data points.
|
||||
Algorithm::Copy(coords, Coords);
|
||||
|
||||
// generate unique id for each input point
|
||||
vtkm::cont::ArrayHandleCounting<vtkm::Id> pointCounting(0, 1, coords.GetNumberOfValues());
|
||||
Algorithm::Copy(pointCounting, PointIds);
|
||||
@ -196,31 +217,17 @@ public:
|
||||
Algorithm::LowerBounds(CellIds, cell_ids_counting, CellLower);
|
||||
}
|
||||
|
||||
/// \brief Nearest neighbor search using a Uniform Grid
|
||||
///
|
||||
/// Parallel search of nearesat neighbor for each point in the \c queryPoints in the set of
|
||||
/// \c coords. Returns neareast neighbot in \c nearestNeighborIds and distances to nearest
|
||||
/// neighbor in \c distances.
|
||||
///
|
||||
/// \param coords Point coordinates for training dataset.
|
||||
/// \param queryPoints Point coordinates to query for nearest neighbors.
|
||||
/// \param nearestNeighborIds Neareast neighbor in the training dataset for each points in
|
||||
/// the test set
|
||||
/// \param distances Distance between query points and their nearest neighbors.
|
||||
/// \param device Tag for selecting device adapter.
|
||||
template <typename DeviceAdapter>
|
||||
void FindNearestPoint(const vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>>& coords,
|
||||
const vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>>& queryPoints,
|
||||
vtkm::cont::ArrayHandle<vtkm::Id>& nearestNeighborIds,
|
||||
vtkm::cont::ArrayHandle<T>& distances,
|
||||
DeviceAdapter)
|
||||
Locator<DeviceAdapter> PrepareForExecution(DeviceAdapter)
|
||||
{
|
||||
UniformGridSearch uniformGridSearch(Min, Max, Dims);
|
||||
|
||||
vtkm::worklet::DispatcherMapField<UniformGridSearch, DeviceAdapter> searchDispatcher(
|
||||
uniformGridSearch);
|
||||
searchDispatcher.Invoke(
|
||||
queryPoints, coords, PointIds, CellLower, CellUpper, nearestNeighborIds, distances);
|
||||
// TODO: lifetime of coords???
|
||||
return Locator<DeviceAdapter>(Min,
|
||||
Max,
|
||||
Dims,
|
||||
Coords.PrepareForInput(DeviceAdapter()),
|
||||
PointIds.PrepareForInput(DeviceAdapter()),
|
||||
CellLower.PrepareForInput(DeviceAdapter()),
|
||||
CellUpper.PrepareForInput(DeviceAdapter()));
|
||||
}
|
||||
|
||||
private:
|
||||
@ -228,6 +235,7 @@ private:
|
||||
vtkm::Vec<T, 3> Max;
|
||||
vtkm::Vec<vtkm::Id, 3> Dims;
|
||||
|
||||
vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>> Coords;
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> PointIds;
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> CellIds;
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> CellLower;
|
||||
|
@ -75,6 +75,30 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class PointLocatorUniformGridWorklet : public vtkm::worklet::WorkletMapField
|
||||
{
|
||||
public:
|
||||
typedef void ControlSignature(FieldIn<> qcIn,
|
||||
ExecObject locator,
|
||||
FieldOut<> nnIdOut,
|
||||
FieldOut<> nnDistOut);
|
||||
|
||||
typedef void ExecutionSignature(_1, _2, _3, _4);
|
||||
|
||||
VTKM_CONT
|
||||
PointLocatorUniformGridWorklet() {}
|
||||
|
||||
// TODO: change IdType, it is used for other purpose.
|
||||
template <typename CoordiVecType, typename Locator, typename IdType, typename CoordiType>
|
||||
VTKM_EXEC void operator()(const CoordiVecType& qc,
|
||||
const Locator& locator,
|
||||
IdType& nnIdOut,
|
||||
CoordiType& nnDis) const
|
||||
{
|
||||
locator.FindNearestPoint(qc, nnIdOut, nnDis);
|
||||
};
|
||||
};
|
||||
|
||||
template <typename DeviceAdapter>
|
||||
class TestingPointLocatorUniformGrid
|
||||
{
|
||||
@ -97,10 +121,12 @@ public:
|
||||
}
|
||||
auto coordi_Handle = vtkm::cont::make_ArrayHandle(coordi);
|
||||
|
||||
vtkm::worklet::PointLocatorUniformGrid<vtkm::Float32> uniformGrid(
|
||||
vtkm::cont::PointLocatorUniformGrid<vtkm::Float32> uniformGrid(
|
||||
{ 0.0f, 0.0f, 0.0f }, { 10.0f, 10.0f, 10.0f }, { 5, 5, 5 });
|
||||
uniformGrid.Build(coordi_Handle, DeviceAdapter());
|
||||
uniformGrid.Build(coordi_Handle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
|
||||
auto locator = uniformGrid.PrepareForExecution(VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
|
||||
|
||||
///// randomly generate training points/////
|
||||
std::vector<vtkm::Vec<vtkm::Float32, 3>> qcVec;
|
||||
for (vtkm::Int32 i = 0; i < nTestingPoint; i++)
|
||||
{
|
||||
@ -111,9 +137,12 @@ public:
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> nnId_Handle;
|
||||
vtkm::cont::ArrayHandle<vtkm::Float32> nnDis_Handle;
|
||||
|
||||
uniformGrid.FindNearestPoint(
|
||||
coordi_Handle, qc_Handle, nnId_Handle, nnDis_Handle, DeviceAdapter());
|
||||
PointLocatorUniformGridWorklet pointLocatorUniformGridWorklet;
|
||||
vtkm::worklet::DispatcherMapField<PointLocatorUniformGridWorklet> locatorDispatcher(
|
||||
pointLocatorUniformGridWorklet);
|
||||
locatorDispatcher.Invoke(qc_Handle, locator, nnId_Handle, nnDis_Handle);
|
||||
|
||||
// brute force
|
||||
vtkm::cont::ArrayHandle<vtkm::Id> bfnnId_Handle;
|
||||
vtkm::cont::ArrayHandle<vtkm::Float32> bfnnDis_Handle;
|
||||
NearestNeighborSearchBruteForce3DWorklet nnsbf3dWorklet;
|
||||
|
Loading…
Reference in New Issue
Block a user