Update Wavelet worklet to support a runtime device adapter id

This commit is contained in:
Robert Maynard 2019-06-11 16:19:08 -04:00
parent 0bf3c7ffcb
commit 99865f47d1
3 changed files with 119 additions and 130 deletions

@ -147,16 +147,6 @@ using AllCellList = vtkm::ListTagJoin<StructuredCellList, UnstructuredCellList>;
using CoordinateList = vtkm::ListTagBase<vtkm::Vec<vtkm::Float32, 3>, vtkm::Vec<vtkm::Float64, 3>>;
struct WaveletGeneratorDataFunctor
{
template <typename DeviceAdapter>
bool operator()(DeviceAdapter, vtkm::worklet::WaveletGenerator& gen)
{
InputDataSet = gen.GenerateDataSet<DeviceAdapter>();
return true;
}
};
class BenchmarkFilterPolicy : public vtkm::filter::PolicyBase<BenchmarkFilterPolicy>
{
public:
@ -1373,9 +1363,7 @@ int BenchmarkBody(int argc, char** argv, const vtkm::cont::InitializeResult& con
vtkm::worklet::WaveletGenerator gen;
gen.SetExtent({ 0 }, { waveletDim });
// WaveletGenerator needs a template device argument not a id to deduce the portal type.
WaveletGeneratorDataFunctor genFunctor;
vtkm::cont::TryExecuteOnDevice(config.Device, genFunctor, gen);
InputDataSet = gen.GenerateDataSet(config.Device);
}
if (tetra)

@ -13,10 +13,10 @@
#include <vtkm/Types.h>
#include <vtkm/cont/Algorithm.h>
#include <vtkm/cont/CellSetStructured.h>
#include <vtkm/cont/CoordinateSystem.h>
#include <vtkm/cont/DataSet.h>
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
#include <vtkm/cont/Field.h>
#include <vtkm/exec/FunctorBase.h>
@ -29,6 +29,95 @@ namespace vtkm
namespace worklet
{
namespace wavelet
{
template <typename Device>
struct Worker : public vtkm::exec::FunctorBase
{
using OutputHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
using OutputPortalType = decltype(std::declval<OutputHandleType>().PrepareForOutput(0, Device()));
using Vec3F = vtkm::Vec<vtkm::FloatDefault, 3>;
Vec3F Center;
Vec3F Spacing;
Vec3F Frequency;
Vec3F Magnitude;
Vec3F MinimumPoint;
Vec3F Scale;
vtkm::Id3 Offset;
vtkm::Id3 Dims;
vtkm::FloatDefault MaximumValue;
vtkm::FloatDefault Temp2;
OutputPortalType Portal;
VTKM_CONT
Worker(OutputHandleType& output,
const Vec3F& center,
const Vec3F& spacing,
const Vec3F& frequency,
const Vec3F& magnitude,
const Vec3F& minimumPoint,
const Vec3F& scale,
const vtkm::Id3& offset,
const vtkm::Id3& dims,
vtkm::FloatDefault maximumValue,
vtkm::FloatDefault temp2)
: Center(center)
, Spacing(spacing)
, Frequency(frequency)
, Magnitude(magnitude)
, MinimumPoint(minimumPoint)
, Scale(scale)
, Offset(offset)
, Dims(dims)
, MaximumValue(maximumValue)
, Temp2(temp2)
, Portal(output.PrepareForOutput((dims[0] * dims[1] * dims[2]), Device{}))
{
}
VTKM_EXEC
void operator()(const vtkm::Id3& ijk) const
{
// map ijk to the point location, accounting for spacing:
const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing;
// Compute the distance from the center of the gaussian:
const Vec3F scaledLoc = (this->Center - loc) * this->Scale;
vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc);
const Vec3F periodicContribs{
this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]),
this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]),
this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]),
};
// The vtkRTAnalyticSource documentation says the periodic contributions
// should be multiplied in, but the implementation adds them. We'll do as
// they do, not as they say.
const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) +
periodicContribs[0] + periodicContribs[1] + periodicContribs[2];
// Compute output location
// (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex)
const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]);
this->Portal.Set(scalarIdx, scalar);
}
};
struct runWorker
{
template <typename Device, typename... Args>
inline bool operator()(Device, const vtkm::Id3 dims, Args... args) const
{
using Algo = vtkm::cont::DeviceAdapterAlgorithm<Device>;
Worker<Device> worker{ args... };
Algo::Schedule(worker, dims);
return true;
}
};
}
/**
* @brief The WaveletGenerator class creates a dataset similar to VTK's
* vtkRTAnalyticSource.
@ -80,81 +169,6 @@ class WaveletGenerator
FloatDefault StandardDeviation;
public:
template <typename Device>
struct Worker : public vtkm::exec::FunctorBase
{
using OutputHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
using OutputPortalType =
decltype(std::declval<OutputHandleType>().PrepareForOutput(0, Device()));
Vec3F Center;
Vec3F Spacing;
Vec3F Frequency;
Vec3F Magnitude;
Vec3F MinimumPoint;
Vec3F Scale;
vtkm::Id3 Offset;
vtkm::Id3 Dims;
vtkm::FloatDefault MaximumValue;
vtkm::FloatDefault Temp2;
OutputHandleType Output;
OutputPortalType Portal;
VTKM_CONT
Worker(const Vec3F& center,
const Vec3F& spacing,
const Vec3F& frequency,
const Vec3F& magnitude,
const Vec3F& minimumPoint,
const Vec3F& scale,
const vtkm::Id3& offset,
const vtkm::Id3& dims,
vtkm::FloatDefault maximumValue,
vtkm::FloatDefault temp2)
: Center(center)
, Spacing(spacing)
, Frequency(frequency)
, Magnitude(magnitude)
, MinimumPoint(minimumPoint)
, Scale(scale)
, Offset(offset)
, Dims(dims)
, MaximumValue(maximumValue)
, Temp2(temp2)
{
const vtkm::Id nVals = dims[0] * dims[1] * dims[2];
this->Portal = this->Output.PrepareForOutput(nVals, Device());
}
VTKM_EXEC
void operator()(const vtkm::Id3& ijk) const
{
// map ijk to the point location, accounting for spacing:
const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing;
// Compute the distance from the center of the gaussian:
const Vec3F scaledLoc = (this->Center - loc) * this->Scale;
vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc);
const Vec3F periodicContribs{
this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]),
this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]),
this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]),
};
// The vtkRTAnalyticSource documentation says the periodic contributions
// should be multiplied in, but the implementation adds them. We'll do as
// they do, not as they say.
const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) +
periodicContribs[0] + periodicContribs[1] + periodicContribs[2];
// Compute output location
// (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex)
const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]);
this->Portal.Set(scalarIdx, scalar);
}
};
VTKM_CONT
WaveletGenerator()
: Center{ 0. }
@ -199,8 +213,8 @@ public:
this->StandardDeviation = stdev;
}
template <typename Device>
VTKM_CONT vtkm::cont::DataSet GenerateDataSet(Device = Device())
VTKM_CONT vtkm::cont::DataSet GenerateDataSet(
vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const
{
// Create points:
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
@ -212,7 +226,7 @@ public:
cellSet.SetPointDimensions(dims);
// Scalars, too
vtkm::cont::Field field = this->GenerateField<Device>("scalars");
vtkm::cont::Field field = this->GenerateField("scalars", device);
// Compile the dataset:
vtkm::cont::DataSet dataSet;
@ -223,11 +237,10 @@ public:
return dataSet;
}
template <typename Device>
VTKM_CONT vtkm::cont::Field GenerateField(const std::string& name, Device = Device())
VTKM_CONT vtkm::cont::Field GenerateField(
const std::string& name,
vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const
{
using Algo = vtkm::cont::DeviceAdapterAlgorithm<Device>;
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
Vec3F minPt = Vec3F(this->MinimumExtent) * this->Spacing;
vtkm::FloatDefault temp2 = 1.f / (2.f * this->StandardDeviation * this->StandardDeviation);
@ -235,20 +248,23 @@ public:
ComputeScaleFactor(this->MinimumExtent[1], this->MaximumExtent[1]),
ComputeScaleFactor(this->MinimumExtent[2], this->MaximumExtent[2]) };
Worker<Device> worker{ this->Center,
this->Spacing,
this->Frequency,
this->Magnitude,
minPt,
scale,
this->MinimumExtent,
dims,
this->MaximumValue,
temp2 };
Algo::Schedule(worker, dims);
return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, worker.Output);
vtkm::cont::ArrayHandle<vtkm::FloatDefault> output;
vtkm::cont::TryExecuteOnDevice(device,
wavelet::runWorker{},
dims,
output,
this->Center,
this->Spacing,
this->Frequency,
this->Magnitude,
minPt,
scale,
this->MinimumExtent,
dims,
this->MaximumValue,
temp2);
return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, output);
}
private:

@ -13,28 +13,13 @@
#include <vtkm/cont/Timer.h>
#include <vtkm/cont/testing/Testing.h>
namespace detail
{
struct WaveletGeneratorFunctor
{
vtkm::cont::DataSet Ds;
template <typename DeviceAdapter>
bool operator()(DeviceAdapter)
{
vtkm::worklet::WaveletGenerator gen;
Ds = gen.GenerateDataSet<DeviceAdapter>();
return true;
}
};
}
void WaveletGeneratorTest()
{
vtkm::cont::Timer timer;
timer.Start();
detail::WaveletGeneratorFunctor wgFunctor;
vtkm::cont::TryExecute(wgFunctor);
vtkm::worklet::WaveletGenerator gen;
vtkm::cont::DataSet ds = gen.GenerateDataSet();
double time = timer.GetElapsedTime();
@ -42,13 +27,13 @@ void WaveletGeneratorTest()
std::cout << "Default wavelet took " << time << "s.\n";
{
auto coords = wgFunctor.Ds.GetCoordinateSystem("coords");
auto coords = ds.GetCoordinateSystem("coords");
auto data = coords.GetData();
VTKM_TEST_ASSERT(test_equal(data.GetNumberOfValues(), 9261), "Incorrect number of points.");
}
{
auto cells = wgFunctor.Ds.GetCellSet(wgFunctor.Ds.GetCellSetIndex("cells"));
auto cells = ds.GetCellSet(ds.GetCellSetIndex("cells"));
VTKM_TEST_ASSERT(test_equal(cells.GetNumberOfCells(), 8000), "Incorrect number of cells.");
}
@ -56,7 +41,7 @@ void WaveletGeneratorTest()
{
using ScalarHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
auto field = wgFunctor.Ds.GetField("scalars", vtkm::cont::Field::Association::POINTS);
auto field = ds.GetField("scalars", vtkm::cont::Field::Association::POINTS);
auto dynData = field.GetData();
VTKM_TEST_ASSERT(dynData.IsType<ScalarHandleType>(), "Invalid scalar handle type.");
ScalarHandleType handle = dynData.Cast<ScalarHandleType>();