Merge branch 'pointlocator' into pointlocator2

This commit is contained in:
Li-Ta Lo 2018-06-19 12:06:54 -06:00
commit 2e519f6508
2 changed files with 101 additions and 64 deletions

@ -28,7 +28,7 @@
namespace vtkm
{
namespace worklet
namespace cont
{
template <typename T>
class PointLocatorUniformGrid
@ -71,46 +71,56 @@ public:
vtkm::Vec<T, 3> Dxdydz;
};
class UniformGridSearch : public vtkm::worklet::WorkletMapField
template <typename DeviceAdapter>
class Locator : public vtkm::exec::ExecutionObjectBase
{
public:
using ControlSignature = void(FieldIn<> query,
WholeArrayIn<> coordIn,
WholeArrayIn<IdType> pointId,
WholeArrayIn<IdType> cellLower,
WholeArrayIn<IdType> cellUpper,
FieldOut<IdType> neighborId,
FieldOut<> distance);
using ExecutionSignature = void(_1, _2, _3, _4, _5, _6, _7);
using CoordPortalType = typename vtkm::cont::ArrayHandle<
vtkm::Vec<T, 3>>::template ExecutionTypes<DeviceAdapter>::PortalConst;
using IdPortalType = typename vtkm::cont::ArrayHandle<vtkm::Id>::template ExecutionTypes<
DeviceAdapter>::PortalConst;
VTKM_CONT
UniformGridSearch(const vtkm::Vec<T, 3>& _min,
const vtkm::Vec<T, 3>& _max,
const vtkm::Vec<vtkm::Id, 3>& _dims)
Locator() = default;
VTKM_CONT
Locator(const vtkm::Vec<T, 3>& _min,
const vtkm::Vec<T, 3>& _max,
const vtkm::Vec<vtkm::Id, 3>& _dims,
const CoordPortalType& coords,
const IdPortalType& pointIds,
const IdPortalType& cellLower,
const IdPortalType& cellUpper)
: Min(_min)
, Dims(_dims)
, Dxdydz((_max - _min) / _dims)
, Dxdydz((_max - Min) / Dims)
, coords(coords)
, pointIds(pointIds)
, cellLower(cellLower)
, cellUpper(cellUpper)
{
}
template <typename CoordiVecType,
typename IdPortalType,
typename CoordiPortalType,
typename IdType,
typename CoordiType>
VTKM_EXEC void operator()(const CoordiVecType& queryCoord,
const CoordiPortalType& coordi_Handle,
const IdPortalType& pointId,
const IdPortalType& cellLower,
const IdPortalType& cellUpper,
IdType& nnId,
CoordiType& nnDis) const
/// \brief Nearest neighbor search using a Uniform Grid
///
/// Parallel search of nearesat neighbor for each point in the \c queryPoints in the set of
/// \c coords. Returns neareast neighbot in \c nearestNeighborIds and distances to nearest
/// neighbor in \c distances.
///
/// \param coords Point coordinates for training dataset.
/// \param queryPoints Point coordinates to query for nearest neighbors.
/// \param nearestNeighborIds Neareast neighbor in the training dataset for each points in
/// the test set
/// \param distances Distance between query points and their nearest neighbors.
/// \param device Tag for selecting device adapter.
VTKM_EXEC
void FindNearestPoint(const vtkm::Vec<T, 3>& queryPoint,
vtkm::Id& nearestNeighborId,
T& distance) const
{
auto nlayers = vtkm::Max(vtkm::Max(Dims[0], Dims[1]), Dims[2]);
vtkm::Vec<vtkm::Id, 3> xyz = (queryCoord - Min) / Dxdydz;
vtkm::Vec<vtkm::Id, 3> xyz = (queryPoint - Min) / Dxdydz;
float min_distance = std::numeric_limits<float>::max();
vtkm::Id neareast = -1;
@ -138,11 +148,11 @@ public:
auto upper = cellUpper.Get(cellid);
for (auto index = lower; index < upper; index++)
{
auto pointid = pointId.Get(index);
auto point = coordi_Handle.Get(pointid);
auto dx = point[0] - queryCoord[0];
auto dy = point[1] - queryCoord[1];
auto dz = point[2] - queryCoord[2];
auto pointid = pointIds.Get(index);
auto point = coords.Get(pointid);
auto dx = point[0] - queryPoint[0];
auto dy = point[1] - queryPoint[1];
auto dz = point[2] - queryPoint[2];
auto distance2 = dx * dx + dy * dy + dz * dz;
if (distance2 < min_distance)
@ -158,14 +168,21 @@ public:
}
}
nnId = neareast;
nnDis = vtkm::Sqrt(min_distance);
}
nearestNeighborId = neareast;
distance = vtkm::Sqrt(min_distance);
};
private:
vtkm::Vec<T, 3> Min;
vtkm::Vec<vtkm::Id, 3> Dims;
vtkm::Vec<T, 3> Dxdydz;
CoordPortalType coords;
IdPortalType pointIds;
IdPortalType cellIds;
IdPortalType cellLower;
IdPortalType cellUpper;
};
/// \brief Construct a 3D uniform grid for nearest neighbor search.
@ -177,6 +194,10 @@ public:
DeviceAdapter vtkmNotUsed(device))
{
using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
// Save training data points.
Algorithm::Copy(coords, Coords);
// generate unique id for each input point
vtkm::cont::ArrayHandleCounting<vtkm::Id> pointCounting(0, 1, coords.GetNumberOfValues());
Algorithm::Copy(pointCounting, PointIds);
@ -196,31 +217,17 @@ public:
Algorithm::LowerBounds(CellIds, cell_ids_counting, CellLower);
}
/// \brief Nearest neighbor search using a Uniform Grid
///
/// Parallel search of nearesat neighbor for each point in the \c queryPoints in the set of
/// \c coords. Returns neareast neighbot in \c nearestNeighborIds and distances to nearest
/// neighbor in \c distances.
///
/// \param coords Point coordinates for training dataset.
/// \param queryPoints Point coordinates to query for nearest neighbors.
/// \param nearestNeighborIds Neareast neighbor in the training dataset for each points in
/// the test set
/// \param distances Distance between query points and their nearest neighbors.
/// \param device Tag for selecting device adapter.
template <typename DeviceAdapter>
void FindNearestPoint(const vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>>& coords,
const vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>>& queryPoints,
vtkm::cont::ArrayHandle<vtkm::Id>& nearestNeighborIds,
vtkm::cont::ArrayHandle<T>& distances,
DeviceAdapter)
Locator<DeviceAdapter> PrepareForExecution(DeviceAdapter)
{
UniformGridSearch uniformGridSearch(Min, Max, Dims);
vtkm::worklet::DispatcherMapField<UniformGridSearch, DeviceAdapter> searchDispatcher(
uniformGridSearch);
searchDispatcher.Invoke(
queryPoints, coords, PointIds, CellLower, CellUpper, nearestNeighborIds, distances);
// TODO: lifetime of coords???
return Locator<DeviceAdapter>(Min,
Max,
Dims,
Coords.PrepareForInput(DeviceAdapter()),
PointIds.PrepareForInput(DeviceAdapter()),
CellLower.PrepareForInput(DeviceAdapter()),
CellUpper.PrepareForInput(DeviceAdapter()));
}
private:
@ -228,6 +235,7 @@ private:
vtkm::Vec<T, 3> Max;
vtkm::Vec<vtkm::Id, 3> Dims;
vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>> Coords;
vtkm::cont::ArrayHandle<vtkm::Id> PointIds;
vtkm::cont::ArrayHandle<vtkm::Id> CellIds;
vtkm::cont::ArrayHandle<vtkm::Id> CellLower;

@ -75,6 +75,30 @@ public:
}
};
class PointLocatorUniformGridWorklet : public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature(FieldIn<> qcIn,
ExecObject locator,
FieldOut<> nnIdOut,
FieldOut<> nnDistOut);
typedef void ExecutionSignature(_1, _2, _3, _4);
VTKM_CONT
PointLocatorUniformGridWorklet() {}
// TODO: change IdType, it is used for other purpose.
template <typename CoordiVecType, typename Locator, typename IdType, typename CoordiType>
VTKM_EXEC void operator()(const CoordiVecType& qc,
const Locator& locator,
IdType& nnIdOut,
CoordiType& nnDis) const
{
locator.FindNearestPoint(qc, nnIdOut, nnDis);
};
};
template <typename DeviceAdapter>
class TestingPointLocatorUniformGrid
{
@ -97,10 +121,12 @@ public:
}
auto coordi_Handle = vtkm::cont::make_ArrayHandle(coordi);
vtkm::worklet::PointLocatorUniformGrid<vtkm::Float32> uniformGrid(
vtkm::cont::PointLocatorUniformGrid<vtkm::Float32> uniformGrid(
{ 0.0f, 0.0f, 0.0f }, { 10.0f, 10.0f, 10.0f }, { 5, 5, 5 });
uniformGrid.Build(coordi_Handle, DeviceAdapter());
uniformGrid.Build(coordi_Handle, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
auto locator = uniformGrid.PrepareForExecution(VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
///// randomly generate training points/////
std::vector<vtkm::Vec<vtkm::Float32, 3>> qcVec;
for (vtkm::Int32 i = 0; i < nTestingPoint; i++)
{
@ -111,9 +137,12 @@ public:
vtkm::cont::ArrayHandle<vtkm::Id> nnId_Handle;
vtkm::cont::ArrayHandle<vtkm::Float32> nnDis_Handle;
uniformGrid.FindNearestPoint(
coordi_Handle, qc_Handle, nnId_Handle, nnDis_Handle, DeviceAdapter());
PointLocatorUniformGridWorklet pointLocatorUniformGridWorklet;
vtkm::worklet::DispatcherMapField<PointLocatorUniformGridWorklet> locatorDispatcher(
pointLocatorUniformGridWorklet);
locatorDispatcher.Invoke(qc_Handle, locator, nnId_Handle, nnDis_Handle);
// brute force
vtkm::cont::ArrayHandle<vtkm::Id> bfnnId_Handle;
vtkm::cont::ArrayHandle<vtkm::Float32> bfnnDis_Handle;
NearestNeighborSearchBruteForce3DWorklet nnsbf3dWorklet;