mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-19 10:35:42 +00:00
Update Wavelet worklet to support a runtime device adapter id
This commit is contained in:
parent
0bf3c7ffcb
commit
99865f47d1
@ -147,16 +147,6 @@ using AllCellList = vtkm::ListTagJoin<StructuredCellList, UnstructuredCellList>;
|
|||||||
|
|
||||||
using CoordinateList = vtkm::ListTagBase<vtkm::Vec<vtkm::Float32, 3>, vtkm::Vec<vtkm::Float64, 3>>;
|
using CoordinateList = vtkm::ListTagBase<vtkm::Vec<vtkm::Float32, 3>, vtkm::Vec<vtkm::Float64, 3>>;
|
||||||
|
|
||||||
struct WaveletGeneratorDataFunctor
|
|
||||||
{
|
|
||||||
template <typename DeviceAdapter>
|
|
||||||
bool operator()(DeviceAdapter, vtkm::worklet::WaveletGenerator& gen)
|
|
||||||
{
|
|
||||||
InputDataSet = gen.GenerateDataSet<DeviceAdapter>();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
class BenchmarkFilterPolicy : public vtkm::filter::PolicyBase<BenchmarkFilterPolicy>
|
class BenchmarkFilterPolicy : public vtkm::filter::PolicyBase<BenchmarkFilterPolicy>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
@ -1373,9 +1363,7 @@ int BenchmarkBody(int argc, char** argv, const vtkm::cont::InitializeResult& con
|
|||||||
vtkm::worklet::WaveletGenerator gen;
|
vtkm::worklet::WaveletGenerator gen;
|
||||||
gen.SetExtent({ 0 }, { waveletDim });
|
gen.SetExtent({ 0 }, { waveletDim });
|
||||||
|
|
||||||
// WaveletGenerator needs a template device argument not a id to deduce the portal type.
|
InputDataSet = gen.GenerateDataSet(config.Device);
|
||||||
WaveletGeneratorDataFunctor genFunctor;
|
|
||||||
vtkm::cont::TryExecuteOnDevice(config.Device, genFunctor, gen);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tetra)
|
if (tetra)
|
||||||
|
@ -13,10 +13,10 @@
|
|||||||
|
|
||||||
#include <vtkm/Types.h>
|
#include <vtkm/Types.h>
|
||||||
|
|
||||||
|
#include <vtkm/cont/Algorithm.h>
|
||||||
#include <vtkm/cont/CellSetStructured.h>
|
#include <vtkm/cont/CellSetStructured.h>
|
||||||
#include <vtkm/cont/CoordinateSystem.h>
|
#include <vtkm/cont/CoordinateSystem.h>
|
||||||
#include <vtkm/cont/DataSet.h>
|
#include <vtkm/cont/DataSet.h>
|
||||||
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
|
|
||||||
#include <vtkm/cont/Field.h>
|
#include <vtkm/cont/Field.h>
|
||||||
|
|
||||||
#include <vtkm/exec/FunctorBase.h>
|
#include <vtkm/exec/FunctorBase.h>
|
||||||
@ -29,6 +29,95 @@ namespace vtkm
|
|||||||
namespace worklet
|
namespace worklet
|
||||||
{
|
{
|
||||||
|
|
||||||
|
namespace wavelet
|
||||||
|
{
|
||||||
|
template <typename Device>
|
||||||
|
struct Worker : public vtkm::exec::FunctorBase
|
||||||
|
{
|
||||||
|
using OutputHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
|
||||||
|
using OutputPortalType = decltype(std::declval<OutputHandleType>().PrepareForOutput(0, Device()));
|
||||||
|
using Vec3F = vtkm::Vec<vtkm::FloatDefault, 3>;
|
||||||
|
|
||||||
|
Vec3F Center;
|
||||||
|
Vec3F Spacing;
|
||||||
|
Vec3F Frequency;
|
||||||
|
Vec3F Magnitude;
|
||||||
|
Vec3F MinimumPoint;
|
||||||
|
Vec3F Scale;
|
||||||
|
vtkm::Id3 Offset;
|
||||||
|
vtkm::Id3 Dims;
|
||||||
|
vtkm::FloatDefault MaximumValue;
|
||||||
|
vtkm::FloatDefault Temp2;
|
||||||
|
OutputPortalType Portal;
|
||||||
|
|
||||||
|
VTKM_CONT
|
||||||
|
Worker(OutputHandleType& output,
|
||||||
|
const Vec3F& center,
|
||||||
|
const Vec3F& spacing,
|
||||||
|
const Vec3F& frequency,
|
||||||
|
const Vec3F& magnitude,
|
||||||
|
const Vec3F& minimumPoint,
|
||||||
|
const Vec3F& scale,
|
||||||
|
const vtkm::Id3& offset,
|
||||||
|
const vtkm::Id3& dims,
|
||||||
|
vtkm::FloatDefault maximumValue,
|
||||||
|
vtkm::FloatDefault temp2)
|
||||||
|
: Center(center)
|
||||||
|
, Spacing(spacing)
|
||||||
|
, Frequency(frequency)
|
||||||
|
, Magnitude(magnitude)
|
||||||
|
, MinimumPoint(minimumPoint)
|
||||||
|
, Scale(scale)
|
||||||
|
, Offset(offset)
|
||||||
|
, Dims(dims)
|
||||||
|
, MaximumValue(maximumValue)
|
||||||
|
, Temp2(temp2)
|
||||||
|
, Portal(output.PrepareForOutput((dims[0] * dims[1] * dims[2]), Device{}))
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
VTKM_EXEC
|
||||||
|
void operator()(const vtkm::Id3& ijk) const
|
||||||
|
{
|
||||||
|
// map ijk to the point location, accounting for spacing:
|
||||||
|
const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing;
|
||||||
|
|
||||||
|
// Compute the distance from the center of the gaussian:
|
||||||
|
const Vec3F scaledLoc = (this->Center - loc) * this->Scale;
|
||||||
|
vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc);
|
||||||
|
|
||||||
|
const Vec3F periodicContribs{
|
||||||
|
this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]),
|
||||||
|
this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]),
|
||||||
|
this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]),
|
||||||
|
};
|
||||||
|
|
||||||
|
// The vtkRTAnalyticSource documentation says the periodic contributions
|
||||||
|
// should be multiplied in, but the implementation adds them. We'll do as
|
||||||
|
// they do, not as they say.
|
||||||
|
const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) +
|
||||||
|
periodicContribs[0] + periodicContribs[1] + periodicContribs[2];
|
||||||
|
|
||||||
|
// Compute output location
|
||||||
|
// (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex)
|
||||||
|
const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]);
|
||||||
|
this->Portal.Set(scalarIdx, scalar);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct runWorker
|
||||||
|
{
|
||||||
|
template <typename Device, typename... Args>
|
||||||
|
inline bool operator()(Device, const vtkm::Id3 dims, Args... args) const
|
||||||
|
{
|
||||||
|
using Algo = vtkm::cont::DeviceAdapterAlgorithm<Device>;
|
||||||
|
Worker<Device> worker{ args... };
|
||||||
|
Algo::Schedule(worker, dims);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief The WaveletGenerator class creates a dataset similar to VTK's
|
* @brief The WaveletGenerator class creates a dataset similar to VTK's
|
||||||
* vtkRTAnalyticSource.
|
* vtkRTAnalyticSource.
|
||||||
@ -80,81 +169,6 @@ class WaveletGenerator
|
|||||||
FloatDefault StandardDeviation;
|
FloatDefault StandardDeviation;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
template <typename Device>
|
|
||||||
struct Worker : public vtkm::exec::FunctorBase
|
|
||||||
{
|
|
||||||
using OutputHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
|
|
||||||
using OutputPortalType =
|
|
||||||
decltype(std::declval<OutputHandleType>().PrepareForOutput(0, Device()));
|
|
||||||
|
|
||||||
Vec3F Center;
|
|
||||||
Vec3F Spacing;
|
|
||||||
Vec3F Frequency;
|
|
||||||
Vec3F Magnitude;
|
|
||||||
Vec3F MinimumPoint;
|
|
||||||
Vec3F Scale;
|
|
||||||
vtkm::Id3 Offset;
|
|
||||||
vtkm::Id3 Dims;
|
|
||||||
vtkm::FloatDefault MaximumValue;
|
|
||||||
vtkm::FloatDefault Temp2;
|
|
||||||
OutputHandleType Output;
|
|
||||||
OutputPortalType Portal;
|
|
||||||
|
|
||||||
VTKM_CONT
|
|
||||||
Worker(const Vec3F& center,
|
|
||||||
const Vec3F& spacing,
|
|
||||||
const Vec3F& frequency,
|
|
||||||
const Vec3F& magnitude,
|
|
||||||
const Vec3F& minimumPoint,
|
|
||||||
const Vec3F& scale,
|
|
||||||
const vtkm::Id3& offset,
|
|
||||||
const vtkm::Id3& dims,
|
|
||||||
vtkm::FloatDefault maximumValue,
|
|
||||||
vtkm::FloatDefault temp2)
|
|
||||||
: Center(center)
|
|
||||||
, Spacing(spacing)
|
|
||||||
, Frequency(frequency)
|
|
||||||
, Magnitude(magnitude)
|
|
||||||
, MinimumPoint(minimumPoint)
|
|
||||||
, Scale(scale)
|
|
||||||
, Offset(offset)
|
|
||||||
, Dims(dims)
|
|
||||||
, MaximumValue(maximumValue)
|
|
||||||
, Temp2(temp2)
|
|
||||||
{
|
|
||||||
const vtkm::Id nVals = dims[0] * dims[1] * dims[2];
|
|
||||||
this->Portal = this->Output.PrepareForOutput(nVals, Device());
|
|
||||||
}
|
|
||||||
|
|
||||||
VTKM_EXEC
|
|
||||||
void operator()(const vtkm::Id3& ijk) const
|
|
||||||
{
|
|
||||||
// map ijk to the point location, accounting for spacing:
|
|
||||||
const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing;
|
|
||||||
|
|
||||||
// Compute the distance from the center of the gaussian:
|
|
||||||
const Vec3F scaledLoc = (this->Center - loc) * this->Scale;
|
|
||||||
vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc);
|
|
||||||
|
|
||||||
const Vec3F periodicContribs{
|
|
||||||
this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]),
|
|
||||||
this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]),
|
|
||||||
this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]),
|
|
||||||
};
|
|
||||||
|
|
||||||
// The vtkRTAnalyticSource documentation says the periodic contributions
|
|
||||||
// should be multiplied in, but the implementation adds them. We'll do as
|
|
||||||
// they do, not as they say.
|
|
||||||
const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) +
|
|
||||||
periodicContribs[0] + periodicContribs[1] + periodicContribs[2];
|
|
||||||
|
|
||||||
// Compute output location
|
|
||||||
// (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex)
|
|
||||||
const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]);
|
|
||||||
this->Portal.Set(scalarIdx, scalar);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
WaveletGenerator()
|
WaveletGenerator()
|
||||||
: Center{ 0. }
|
: Center{ 0. }
|
||||||
@ -199,8 +213,8 @@ public:
|
|||||||
this->StandardDeviation = stdev;
|
this->StandardDeviation = stdev;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Device>
|
VTKM_CONT vtkm::cont::DataSet GenerateDataSet(
|
||||||
VTKM_CONT vtkm::cont::DataSet GenerateDataSet(Device = Device())
|
vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const
|
||||||
{
|
{
|
||||||
// Create points:
|
// Create points:
|
||||||
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
|
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
|
||||||
@ -212,7 +226,7 @@ public:
|
|||||||
cellSet.SetPointDimensions(dims);
|
cellSet.SetPointDimensions(dims);
|
||||||
|
|
||||||
// Scalars, too
|
// Scalars, too
|
||||||
vtkm::cont::Field field = this->GenerateField<Device>("scalars");
|
vtkm::cont::Field field = this->GenerateField("scalars", device);
|
||||||
|
|
||||||
// Compile the dataset:
|
// Compile the dataset:
|
||||||
vtkm::cont::DataSet dataSet;
|
vtkm::cont::DataSet dataSet;
|
||||||
@ -223,11 +237,10 @@ public:
|
|||||||
return dataSet;
|
return dataSet;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Device>
|
VTKM_CONT vtkm::cont::Field GenerateField(
|
||||||
VTKM_CONT vtkm::cont::Field GenerateField(const std::string& name, Device = Device())
|
const std::string& name,
|
||||||
|
vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const
|
||||||
{
|
{
|
||||||
using Algo = vtkm::cont::DeviceAdapterAlgorithm<Device>;
|
|
||||||
|
|
||||||
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
|
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
|
||||||
Vec3F minPt = Vec3F(this->MinimumExtent) * this->Spacing;
|
Vec3F minPt = Vec3F(this->MinimumExtent) * this->Spacing;
|
||||||
vtkm::FloatDefault temp2 = 1.f / (2.f * this->StandardDeviation * this->StandardDeviation);
|
vtkm::FloatDefault temp2 = 1.f / (2.f * this->StandardDeviation * this->StandardDeviation);
|
||||||
@ -235,7 +248,13 @@ public:
|
|||||||
ComputeScaleFactor(this->MinimumExtent[1], this->MaximumExtent[1]),
|
ComputeScaleFactor(this->MinimumExtent[1], this->MaximumExtent[1]),
|
||||||
ComputeScaleFactor(this->MinimumExtent[2], this->MaximumExtent[2]) };
|
ComputeScaleFactor(this->MinimumExtent[2], this->MaximumExtent[2]) };
|
||||||
|
|
||||||
Worker<Device> worker{ this->Center,
|
|
||||||
|
vtkm::cont::ArrayHandle<vtkm::FloatDefault> output;
|
||||||
|
vtkm::cont::TryExecuteOnDevice(device,
|
||||||
|
wavelet::runWorker{},
|
||||||
|
dims,
|
||||||
|
output,
|
||||||
|
this->Center,
|
||||||
this->Spacing,
|
this->Spacing,
|
||||||
this->Frequency,
|
this->Frequency,
|
||||||
this->Magnitude,
|
this->Magnitude,
|
||||||
@ -244,11 +263,8 @@ public:
|
|||||||
this->MinimumExtent,
|
this->MinimumExtent,
|
||||||
dims,
|
dims,
|
||||||
this->MaximumValue,
|
this->MaximumValue,
|
||||||
temp2 };
|
temp2);
|
||||||
|
return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, output);
|
||||||
Algo::Schedule(worker, dims);
|
|
||||||
|
|
||||||
return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, worker.Output);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -13,28 +13,13 @@
|
|||||||
#include <vtkm/cont/Timer.h>
|
#include <vtkm/cont/Timer.h>
|
||||||
#include <vtkm/cont/testing/Testing.h>
|
#include <vtkm/cont/testing/Testing.h>
|
||||||
|
|
||||||
namespace detail
|
|
||||||
{
|
|
||||||
|
|
||||||
struct WaveletGeneratorFunctor
|
|
||||||
{
|
|
||||||
vtkm::cont::DataSet Ds;
|
|
||||||
template <typename DeviceAdapter>
|
|
||||||
bool operator()(DeviceAdapter)
|
|
||||||
{
|
|
||||||
vtkm::worklet::WaveletGenerator gen;
|
|
||||||
Ds = gen.GenerateDataSet<DeviceAdapter>();
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
void WaveletGeneratorTest()
|
void WaveletGeneratorTest()
|
||||||
{
|
{
|
||||||
vtkm::cont::Timer timer;
|
vtkm::cont::Timer timer;
|
||||||
timer.Start();
|
timer.Start();
|
||||||
detail::WaveletGeneratorFunctor wgFunctor;
|
|
||||||
vtkm::cont::TryExecute(wgFunctor);
|
vtkm::worklet::WaveletGenerator gen;
|
||||||
|
vtkm::cont::DataSet ds = gen.GenerateDataSet();
|
||||||
|
|
||||||
|
|
||||||
double time = timer.GetElapsedTime();
|
double time = timer.GetElapsedTime();
|
||||||
@ -42,13 +27,13 @@ void WaveletGeneratorTest()
|
|||||||
std::cout << "Default wavelet took " << time << "s.\n";
|
std::cout << "Default wavelet took " << time << "s.\n";
|
||||||
|
|
||||||
{
|
{
|
||||||
auto coords = wgFunctor.Ds.GetCoordinateSystem("coords");
|
auto coords = ds.GetCoordinateSystem("coords");
|
||||||
auto data = coords.GetData();
|
auto data = coords.GetData();
|
||||||
VTKM_TEST_ASSERT(test_equal(data.GetNumberOfValues(), 9261), "Incorrect number of points.");
|
VTKM_TEST_ASSERT(test_equal(data.GetNumberOfValues(), 9261), "Incorrect number of points.");
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
auto cells = wgFunctor.Ds.GetCellSet(wgFunctor.Ds.GetCellSetIndex("cells"));
|
auto cells = ds.GetCellSet(ds.GetCellSetIndex("cells"));
|
||||||
VTKM_TEST_ASSERT(test_equal(cells.GetNumberOfCells(), 8000), "Incorrect number of cells.");
|
VTKM_TEST_ASSERT(test_equal(cells.GetNumberOfCells(), 8000), "Incorrect number of cells.");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,7 +41,7 @@ void WaveletGeneratorTest()
|
|||||||
{
|
{
|
||||||
using ScalarHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
|
using ScalarHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
|
||||||
|
|
||||||
auto field = wgFunctor.Ds.GetField("scalars", vtkm::cont::Field::Association::POINTS);
|
auto field = ds.GetField("scalars", vtkm::cont::Field::Association::POINTS);
|
||||||
auto dynData = field.GetData();
|
auto dynData = field.GetData();
|
||||||
VTKM_TEST_ASSERT(dynData.IsType<ScalarHandleType>(), "Invalid scalar handle type.");
|
VTKM_TEST_ASSERT(dynData.IsType<ScalarHandleType>(), "Invalid scalar handle type.");
|
||||||
ScalarHandleType handle = dynData.Cast<ScalarHandleType>();
|
ScalarHandleType handle = dynData.Cast<ScalarHandleType>();
|
||||||
|
Loading…
Reference in New Issue
Block a user