Update Wavelet to use WorkletVisitPointsWithCells

The Wavelet worklet doesn't need a custom scheduler but
can use WorkletVisitPointsWithCells
This commit is contained in:
Robert Maynard 2019-09-13 15:07:40 -04:00
parent 62357dea0f
commit 5c69f1bae9
2 changed files with 30 additions and 51 deletions

@ -25,11 +25,13 @@ namespace source
{
namespace wavelet
{
template <typename Device>
struct Worker : public vtkm::exec::FunctorBase
struct WaveletField : public vtkm::worklet::WorkletVisitPointsWithCells
{
using OutputHandleType = vtkm::cont::ArrayHandle<vtkm::FloatDefault>;
using OutputPortalType = decltype(std::declval<OutputHandleType>().PrepareForOutput(0, Device()));
using ControlSignature = void(CellSetIn, FieldOut v);
using ExecutionSignature = void(ThreadIndices, _2);
using InputDomain = _1;
using Vec3F = vtkm::Vec3f;
Vec3F Center;
@ -42,20 +44,17 @@ struct Worker : public vtkm::exec::FunctorBase
vtkm::Id3 Dims;
vtkm::FloatDefault MaximumValue;
vtkm::FloatDefault Temp2;
OutputPortalType Portal;
VTKM_CONT
Worker(OutputHandleType& output,
const Vec3F& center,
const Vec3F& spacing,
const Vec3F& frequency,
const Vec3F& magnitude,
const Vec3F& minimumPoint,
const Vec3F& scale,
const vtkm::Id3& offset,
const vtkm::Id3& dims,
vtkm::FloatDefault maximumValue,
vtkm::FloatDefault temp2)
WaveletField(const Vec3F& center,
const Vec3F& spacing,
const Vec3F& frequency,
const Vec3F& magnitude,
const Vec3F& minimumPoint,
const Vec3F& scale,
const vtkm::Id3& offset,
const vtkm::Id3& dims,
vtkm::FloatDefault maximumValue,
vtkm::FloatDefault temp2)
: Center(center)
, Spacing(spacing)
, Frequency(frequency)
@ -66,13 +65,14 @@ struct Worker : public vtkm::exec::FunctorBase
, Dims(dims)
, MaximumValue(maximumValue)
, Temp2(temp2)
, Portal(output.PrepareForOutput((dims[0] * dims[1] * dims[2]), Device{}))
{
}
VTKM_EXEC
void operator()(const vtkm::Id3& ijk) const
template <typename ThreadIndexType>
VTKM_EXEC void operator()(const ThreadIndexType& threadIndex, vtkm::FloatDefault& scalar) const
{
const vtkm::Id3 ijk = threadIndex.GetInputIndex3D();
// map ijk to the point location, accounting for spacing:
const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing;
@ -89,25 +89,8 @@ struct Worker : public vtkm::exec::FunctorBase
// The vtkRTAnalyticSource documentation says the periodic contributions
// should be multiplied in, but the implementation adds them. We'll do as
// they do, not as they say.
const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) +
periodicContribs[0] + periodicContribs[1] + periodicContribs[2];
// Compute output location
// (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex)
const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]);
this->Portal.Set(scalarIdx, scalar);
}
};
struct runWorker
{
template <typename Device, typename... Args>
inline bool operator()(Device, const vtkm::Id3 dims, Args... args) const
{
using Algo = vtkm::cont::DeviceAdapterAlgorithm<Device>;
Worker<Device> worker{ args... };
Algo::Schedule(worker, dims);
return true;
scalar =
this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) + vtkm::ReduceSum(periodicContribs);
}
};
} // namespace wavelet
@ -143,16 +126,15 @@ vtkm::cont::DataSet Wavelet::Execute() const
dataSet.SetCellSet(cellSet);
// Scalars, too
vtkm::cont::Field field = this->GeneratePointField("scalars");
vtkm::cont::Field field = this->GeneratePointField(cellSet, "scalars");
dataSet.AddField(field);
return dataSet;
}
vtkm::cont::Field Wavelet::GeneratePointField(const std::string& name) const
vtkm::cont::Field Wavelet::GeneratePointField(const vtkm::cont::CellSetStructured<3>& cellset,
const std::string& name) const
{
VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf);
const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } };
vtkm::Vec3f minPt = vtkm::Vec3f(this->MinimumExtent) * this->Spacing;
vtkm::FloatDefault temp2 = 1.f / (2.f * this->StandardDeviation * this->StandardDeviation);
@ -160,13 +142,8 @@ vtkm::cont::Field Wavelet::GeneratePointField(const std::string& name) const
computeScaleFactor(this->MinimumExtent[1], this->MaximumExtent[1]),
computeScaleFactor(this->MinimumExtent[2], this->MaximumExtent[2]) };
vtkm::cont::ArrayHandle<vtkm::FloatDefault> output;
vtkm::cont::TryExecuteOnDevice(this->Invoke.GetDevice(),
wavelet::runWorker{},
dims,
output,
this->Center,
wavelet::WaveletField worklet{ this->Center,
this->Spacing,
this->Frequency,
this->Magnitude,
@ -175,7 +152,8 @@ vtkm::cont::Field Wavelet::GeneratePointField(const std::string& name) const
this->MinimumExtent,
dims,
this->MaximumValue,
temp2);
temp2 };
this->Invoke(worklet, cellset, output);
return vtkm::cont::make_FieldPoint(name, output);
}

@ -94,7 +94,8 @@ public:
vtkm::cont::DataSet Execute() const;
private:
vtkm::cont::Field GeneratePointField(const std::string& name) const;
vtkm::cont::Field GeneratePointField(const vtkm::cont::CellSetStructured<3>& cellset,
const std::string& name) const;
vtkm::Vec3f Center;
vtkm::Vec3f Spacing;