mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-18 18:15:44 +00:00
Update Wavelet worklet to support a runtime device adapter id
This commit is contained in:
parent
0bf3c7ffcb
commit
99865f47d1
@ -147,16 +147,6 @@ using AllCellList = vtkm::ListTagJoin<StructuredCellList, UnstructuredCellList>;
|
||||
|
||||
using CoordinateList = vtkm::ListTagBase<vtkm::Vec<vtkm::Float32, 3>, vtkm::Vec<vtkm::Float64, 3>>;
|
||||
|
||||
struct WaveletGeneratorDataFunctor
|
||||
{
|
||||
template <typename DeviceAdapter>
|
||||
bool operator()(DeviceAdapter, vtkm::worklet::WaveletGenerator& gen)
|
||||
{
|
||||
InputDataSet = gen.GenerateDataSet<DeviceAdapter>();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class BenchmarkFilterPolicy : public vtkm::filter::PolicyBase<BenchmarkFilterPolicy>
|
||||
{
|
||||
public:
|
||||
@ -1373,9 +1363,7 @@ int BenchmarkBody(int argc, char** argv, const vtkm::cont::InitializeResult& con
|
||||
vtkm::worklet::WaveletGenerator gen;
|
||||
gen.SetExtent({ 0 }, { waveletDim });
|
||||
|
||||
// WaveletGenerator needs a template device argument not a id to deduce the portal type.
|
||||
WaveletGeneratorDataFunctor genFunctor;
|
||||
vtkm::cont::TryExecuteOnDevice(config.Device, genFunctor, gen);
|
||||
InputDataSet = gen.GenerateDataSet(config.Device);
|
||||
}
|
||||
|
||||
if (tetra)
|
||||
|
@ -13,10 +13,10 @@
|
||||
|
||||
#include <vtkm/Types.h>
|
||||
|
||||
#include <vtkm/cont/Algorithm.h>
|
||||
#include <vtkm/cont/CellSetStructured.h>
|
||||
#include <vtkm/cont/CoordinateSystem.h>
|
||||
#include <vtkm/cont/DataSet.h>
|
||||
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
|
||||
#include <vtkm/cont/Field.h>
|
||||
|
||||
#include <vtkm/exec/FunctorBase.h>
|
||||
@ -29,6 +29,95 @@ namespace vtkm
|
||||
namespace worklet
|
||||
{
|
||||
|
||||
namespace wavelet
|
||||
{
|
||||
template <typename Device>
|
||||
struct Worker : public vtkm::exec::FunctorBase
|
||||
{
|
||||
using OutputHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
|
||||
using OutputPortalType = decltype(std::declval<OutputHandleType>().PrepareForOutput(0, Device()));
|
||||
using Vec3F = vtkm::Vec<vtkm::FloatDefault, 3>;
|
||||
|
||||
Vec3F Center;
|
||||
Vec3F Spacing;
|
||||
Vec3F Frequency;
|
||||
Vec3F Magnitude;
|
||||
Vec3F MinimumPoint;
|
||||
Vec3F Scale;
|
||||
vtkm::Id3 Offset;
|
||||
vtkm::Id3 Dims;
|
||||
vtkm::FloatDefault MaximumValue;
|
||||
vtkm::FloatDefault Temp2;
|
||||
OutputPortalType Portal;
|
||||
|
||||
VTKM_CONT
|
||||
Worker(OutputHandleType& output,
|
||||
const Vec3F& center,
|
||||
const Vec3F& spacing,
|
||||
const Vec3F& frequency,
|
||||
const Vec3F& magnitude,
|
||||
const Vec3F& minimumPoint,
|
||||
const Vec3F& scale,
|
||||
const vtkm::Id3& offset,
|
||||
const vtkm::Id3& dims,
|
||||
vtkm::FloatDefault maximumValue,
|
||||
vtkm::FloatDefault temp2)
|
||||
: Center(center)
|
||||
, Spacing(spacing)
|
||||
, Frequency(frequency)
|
||||
, Magnitude(magnitude)
|
||||
, MinimumPoint(minimumPoint)
|
||||
, Scale(scale)
|
||||
, Offset(offset)
|
||||
, Dims(dims)
|
||||
, MaximumValue(maximumValue)
|
||||
, Temp2(temp2)
|
||||
, Portal(output.PrepareForOutput((dims[0] * dims[1] * dims[2]), Device{}))
|
||||
{
|
||||
}
|
||||
|
||||
VTKM_EXEC
|
||||
void operator()(const vtkm::Id3& ijk) const
|
||||
{
|
||||
// map ijk to the point location, accounting for spacing:
|
||||
const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing;
|
||||
|
||||
// Compute the distance from the center of the gaussian:
|
||||
const Vec3F scaledLoc = (this->Center - loc) * this->Scale;
|
||||
vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc);
|
||||
|
||||
const Vec3F periodicContribs{
|
||||
this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]),
|
||||
this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]),
|
||||
this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]),
|
||||
};
|
||||
|
||||
// The vtkRTAnalyticSource documentation says the periodic contributions
|
||||
// should be multiplied in, but the implementation adds them. We'll do as
|
||||
// they do, not as they say.
|
||||
const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) +
|
||||
periodicContribs[0] + periodicContribs[1] + periodicContribs[2];
|
||||
|
||||
// Compute output location
|
||||
// (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex)
|
||||
const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]);
|
||||
this->Portal.Set(scalarIdx, scalar);
|
||||
}
|
||||
};
|
||||
|
||||
struct runWorker
|
||||
{
|
||||
template <typename Device, typename... Args>
|
||||
inline bool operator()(Device, const vtkm::Id3 dims, Args... args) const
|
||||
{
|
||||
using Algo = vtkm::cont::DeviceAdapterAlgorithm<Device>;
|
||||
Worker<Device> worker{ args... };
|
||||
Algo::Schedule(worker, dims);
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief The WaveletGenerator class creates a dataset similar to VTK's
|
||||
* vtkRTAnalyticSource.
|
||||
@ -80,81 +169,6 @@ class WaveletGenerator
|
||||
FloatDefault StandardDeviation;
|
||||
|
||||
public:
|
||||
template <typename Device>
|
||||
struct Worker : public vtkm::exec::FunctorBase
|
||||
{
|
||||
using OutputHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
|
||||
using OutputPortalType =
|
||||
decltype(std::declval<OutputHandleType>().PrepareForOutput(0, Device()));
|
||||
|
||||
Vec3F Center;
|
||||
Vec3F Spacing;
|
||||
Vec3F Frequency;
|
||||
Vec3F Magnitude;
|
||||
Vec3F MinimumPoint;
|
||||
Vec3F Scale;
|
||||
vtkm::Id3 Offset;
|
||||
vtkm::Id3 Dims;
|
||||
vtkm::FloatDefault MaximumValue;
|
||||
vtkm::FloatDefault Temp2;
|
||||
OutputHandleType Output;
|
||||
OutputPortalType Portal;
|
||||
|
||||
VTKM_CONT
|
||||
Worker(const Vec3F& center,
|
||||
const Vec3F& spacing,
|
||||
const Vec3F& frequency,
|
||||
const Vec3F& magnitude,
|
||||
const Vec3F& minimumPoint,
|
||||
const Vec3F& scale,
|
||||
const vtkm::Id3& offset,
|
||||
const vtkm::Id3& dims,
|
||||
vtkm::FloatDefault maximumValue,
|
||||
vtkm::FloatDefault temp2)
|
||||
: Center(center)
|
||||
, Spacing(spacing)
|
||||
, Frequency(frequency)
|
||||
, Magnitude(magnitude)
|
||||
, MinimumPoint(minimumPoint)
|
||||
, Scale(scale)
|
||||
, Offset(offset)
|
||||
, Dims(dims)
|
||||
, MaximumValue(maximumValue)
|
||||
, Temp2(temp2)
|
||||
{
|
||||
const vtkm::Id nVals = dims[0] * dims[1] * dims[2];
|
||||
this->Portal = this->Output.PrepareForOutput(nVals, Device());
|
||||
}
|
||||
|
||||
VTKM_EXEC
|
||||
void operator()(const vtkm::Id3& ijk) const
|
||||
{
|
||||
// map ijk to the point location, accounting for spacing:
|
||||
const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing;
|
||||
|
||||
// Compute the distance from the center of the gaussian:
|
||||
const Vec3F scaledLoc = (this->Center - loc) * this->Scale;
|
||||
vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc);
|
||||
|
||||
const Vec3F periodicContribs{
|
||||
this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]),
|
||||
this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]),
|
||||
this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]),
|
||||
};
|
||||
|
||||
// The vtkRTAnalyticSource documentation says the periodic contributions
|
||||
// should be multiplied in, but the implementation adds them. We'll do as
|
||||
// they do, not as they say.
|
||||
const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) +
|
||||
periodicContribs[0] + periodicContribs[1] + periodicContribs[2];
|
||||
|
||||
// Compute output location
|
||||
// (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex)
|
||||
const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]);
|
||||
this->Portal.Set(scalarIdx, scalar);
|
||||
}
|
||||
};
|
||||
|
||||
VTKM_CONT
|
||||
WaveletGenerator()
|
||||
: Center{ 0. }
|
||||
@ -199,8 +213,8 @@ public:
|
||||
this->StandardDeviation = stdev;
|
||||
}
|
||||
|
||||
template <typename Device>
|
||||
VTKM_CONT vtkm::cont::DataSet GenerateDataSet(Device = Device())
|
||||
VTKM_CONT vtkm::cont::DataSet GenerateDataSet(
|
||||
vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const
|
||||
{
|
||||
// Create points:
|
||||
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
|
||||
@ -212,7 +226,7 @@ public:
|
||||
cellSet.SetPointDimensions(dims);
|
||||
|
||||
// Scalars, too
|
||||
vtkm::cont::Field field = this->GenerateField<Device>("scalars");
|
||||
vtkm::cont::Field field = this->GenerateField("scalars", device);
|
||||
|
||||
// Compile the dataset:
|
||||
vtkm::cont::DataSet dataSet;
|
||||
@ -223,11 +237,10 @@ public:
|
||||
return dataSet;
|
||||
}
|
||||
|
||||
template <typename Device>
|
||||
VTKM_CONT vtkm::cont::Field GenerateField(const std::string& name, Device = Device())
|
||||
VTKM_CONT vtkm::cont::Field GenerateField(
|
||||
const std::string& name,
|
||||
vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const
|
||||
{
|
||||
using Algo = vtkm::cont::DeviceAdapterAlgorithm<Device>;
|
||||
|
||||
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
|
||||
Vec3F minPt = Vec3F(this->MinimumExtent) * this->Spacing;
|
||||
vtkm::FloatDefault temp2 = 1.f / (2.f * this->StandardDeviation * this->StandardDeviation);
|
||||
@ -235,7 +248,13 @@ public:
|
||||
ComputeScaleFactor(this->MinimumExtent[1], this->MaximumExtent[1]),
|
||||
ComputeScaleFactor(this->MinimumExtent[2], this->MaximumExtent[2]) };
|
||||
|
||||
Worker<Device> worker{ this->Center,
|
||||
|
||||
vtkm::cont::ArrayHandle<vtkm::FloatDefault> output;
|
||||
vtkm::cont::TryExecuteOnDevice(device,
|
||||
wavelet::runWorker{},
|
||||
dims,
|
||||
output,
|
||||
this->Center,
|
||||
this->Spacing,
|
||||
this->Frequency,
|
||||
this->Magnitude,
|
||||
@ -244,11 +263,8 @@ public:
|
||||
this->MinimumExtent,
|
||||
dims,
|
||||
this->MaximumValue,
|
||||
temp2 };
|
||||
|
||||
Algo::Schedule(worker, dims);
|
||||
|
||||
return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, worker.Output);
|
||||
temp2);
|
||||
return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, output);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -13,28 +13,13 @@
|
||||
#include <vtkm/cont/Timer.h>
|
||||
#include <vtkm/cont/testing/Testing.h>
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
struct WaveletGeneratorFunctor
|
||||
{
|
||||
vtkm::cont::DataSet Ds;
|
||||
template <typename DeviceAdapter>
|
||||
bool operator()(DeviceAdapter)
|
||||
{
|
||||
vtkm::worklet::WaveletGenerator gen;
|
||||
Ds = gen.GenerateDataSet<DeviceAdapter>();
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
void WaveletGeneratorTest()
|
||||
{
|
||||
vtkm::cont::Timer timer;
|
||||
timer.Start();
|
||||
detail::WaveletGeneratorFunctor wgFunctor;
|
||||
vtkm::cont::TryExecute(wgFunctor);
|
||||
|
||||
vtkm::worklet::WaveletGenerator gen;
|
||||
vtkm::cont::DataSet ds = gen.GenerateDataSet();
|
||||
|
||||
|
||||
double time = timer.GetElapsedTime();
|
||||
@ -42,13 +27,13 @@ void WaveletGeneratorTest()
|
||||
std::cout << "Default wavelet took " << time << "s.\n";
|
||||
|
||||
{
|
||||
auto coords = wgFunctor.Ds.GetCoordinateSystem("coords");
|
||||
auto coords = ds.GetCoordinateSystem("coords");
|
||||
auto data = coords.GetData();
|
||||
VTKM_TEST_ASSERT(test_equal(data.GetNumberOfValues(), 9261), "Incorrect number of points.");
|
||||
}
|
||||
|
||||
{
|
||||
auto cells = wgFunctor.Ds.GetCellSet(wgFunctor.Ds.GetCellSetIndex("cells"));
|
||||
auto cells = ds.GetCellSet(ds.GetCellSetIndex("cells"));
|
||||
VTKM_TEST_ASSERT(test_equal(cells.GetNumberOfCells(), 8000), "Incorrect number of cells.");
|
||||
}
|
||||
|
||||
@ -56,7 +41,7 @@ void WaveletGeneratorTest()
|
||||
{
|
||||
using ScalarHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
|
||||
|
||||
auto field = wgFunctor.Ds.GetField("scalars", vtkm::cont::Field::Association::POINTS);
|
||||
auto field = ds.GetField("scalars", vtkm::cont::Field::Association::POINTS);
|
||||
auto dynData = field.GetData();
|
||||
VTKM_TEST_ASSERT(dynData.IsType<ScalarHandleType>(), "Invalid scalar handle type.");
|
||||
ScalarHandleType handle = dynData.Cast<ScalarHandleType>();
|
||||
|
Loading…
Reference in New Issue
Block a user