From 5c69f1bae938915fd6e123eed97fca8e8037c3aa Mon Sep 17 00:00:00 2001 From: Robert Maynard Date: Fri, 13 Sep 2019 15:07:40 -0400 Subject: [PATCH] Update Wavelet to use WorkletVisitPointsWithCells The Wavelet worklet doesn't need a custom scheduler but can use WorkletVisitPointsWithCells --- vtkm/source/Wavelet.cxx | 78 +++++++++++++++-------------------------- vtkm/source/Wavelet.h | 3 +- 2 files changed, 30 insertions(+), 51 deletions(-) diff --git a/vtkm/source/Wavelet.cxx b/vtkm/source/Wavelet.cxx index 8d00ccf22..6c0fc042b 100644 --- a/vtkm/source/Wavelet.cxx +++ b/vtkm/source/Wavelet.cxx @@ -25,11 +25,13 @@ namespace source { namespace wavelet { -template -struct Worker : public vtkm::exec::FunctorBase + +struct WaveletField : public vtkm::worklet::WorkletVisitPointsWithCells { - using OutputHandleType = vtkm::cont::ArrayHandle; - using OutputPortalType = decltype(std::declval().PrepareForOutput(0, Device())); + using ControlSignature = void(CellSetIn, FieldOut v); + using ExecutionSignature = void(ThreadIndices, _2); + using InputDomain = _1; + using Vec3F = vtkm::Vec3f; Vec3F Center; @@ -42,20 +44,17 @@ struct Worker : public vtkm::exec::FunctorBase vtkm::Id3 Dims; vtkm::FloatDefault MaximumValue; vtkm::FloatDefault Temp2; - OutputPortalType Portal; - VTKM_CONT - Worker(OutputHandleType& output, - const Vec3F& center, - const Vec3F& spacing, - const Vec3F& frequency, - const Vec3F& magnitude, - const Vec3F& minimumPoint, - const Vec3F& scale, - const vtkm::Id3& offset, - const vtkm::Id3& dims, - vtkm::FloatDefault maximumValue, - vtkm::FloatDefault temp2) + WaveletField(const Vec3F& center, + const Vec3F& spacing, + const Vec3F& frequency, + const Vec3F& magnitude, + const Vec3F& minimumPoint, + const Vec3F& scale, + const vtkm::Id3& offset, + const vtkm::Id3& dims, + vtkm::FloatDefault maximumValue, + vtkm::FloatDefault temp2) : Center(center) , Spacing(spacing) , Frequency(frequency) @@ -66,13 +65,14 @@ struct Worker : public vtkm::exec::FunctorBase , Dims(dims) , MaximumValue(maximumValue) , Temp2(temp2) - , Portal(output.PrepareForOutput((dims[0] * dims[1] * dims[2]), Device{})) { } - VTKM_EXEC - void operator()(const vtkm::Id3& ijk) const + template + VTKM_EXEC void operator()(const ThreadIndexType& threadIndex, vtkm::FloatDefault& scalar) const { + const vtkm::Id3 ijk = threadIndex.GetInputIndex3D(); + // map ijk to the point location, accounting for spacing: const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing; @@ -89,25 +89,8 @@ struct Worker : public vtkm::exec::FunctorBase // The vtkRTAnalyticSource documentation says the periodic contributions // should be multiplied in, but the implementation adds them. We'll do as // they do, not as they say. - const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) + - periodicContribs[0] + periodicContribs[1] + periodicContribs[2]; - - // Compute output location - // (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex) - const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]); - this->Portal.Set(scalarIdx, scalar); - } -}; - -struct runWorker -{ - template - inline bool operator()(Device, const vtkm::Id3 dims, Args... args) const - { - using Algo = vtkm::cont::DeviceAdapterAlgorithm; - Worker worker{ args... }; - Algo::Schedule(worker, dims); - return true; + scalar = + this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) + vtkm::ReduceSum(periodicContribs); } }; } // namespace wavelet @@ -143,16 +126,15 @@ vtkm::cont::DataSet Wavelet::Execute() const dataSet.SetCellSet(cellSet); // Scalars, too - vtkm::cont::Field field = this->GeneratePointField("scalars"); + vtkm::cont::Field field = this->GeneratePointField(cellSet, "scalars"); dataSet.AddField(field); return dataSet; } -vtkm::cont::Field Wavelet::GeneratePointField(const std::string& name) const +vtkm::cont::Field Wavelet::GeneratePointField(const vtkm::cont::CellSetStructured<3>& cellset, + const std::string& name) const { - VTKM_LOG_SCOPE_FUNCTION(vtkm::cont::LogLevel::Perf); - const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } }; vtkm::Vec3f minPt = vtkm::Vec3f(this->MinimumExtent) * this->Spacing; vtkm::FloatDefault temp2 = 1.f / (2.f * this->StandardDeviation * this->StandardDeviation); @@ -160,13 +142,8 @@ vtkm::cont::Field Wavelet::GeneratePointField(const std::string& name) const computeScaleFactor(this->MinimumExtent[1], this->MaximumExtent[1]), computeScaleFactor(this->MinimumExtent[2], this->MaximumExtent[2]) }; - vtkm::cont::ArrayHandle output; - vtkm::cont::TryExecuteOnDevice(this->Invoke.GetDevice(), - wavelet::runWorker{}, - dims, - output, - this->Center, + wavelet::WaveletField worklet{ this->Center, this->Spacing, this->Frequency, this->Magnitude, @@ -175,7 +152,8 @@ vtkm::cont::Field Wavelet::GeneratePointField(const std::string& name) const this->MinimumExtent, dims, this->MaximumValue, - temp2); + temp2 }; + this->Invoke(worklet, cellset, output); return vtkm::cont::make_FieldPoint(name, output); } diff --git a/vtkm/source/Wavelet.h b/vtkm/source/Wavelet.h index bfc6fff41..dfcac1931 100644 --- a/vtkm/source/Wavelet.h +++ b/vtkm/source/Wavelet.h @@ -94,7 +94,8 @@ public: vtkm::cont::DataSet Execute() const; private: - vtkm::cont::Field GeneratePointField(const std::string& name) const; + vtkm::cont::Field GeneratePointField(const vtkm::cont::CellSetStructured<3>& cellset, + const std::string& name) const; vtkm::Vec3f Center; vtkm::Vec3f Spacing;