diff --git a/benchmarking/BenchmarkFilters.cxx b/benchmarking/BenchmarkFilters.cxx index f235b9427..6b8684c35 100644 --- a/benchmarking/BenchmarkFilters.cxx +++ b/benchmarking/BenchmarkFilters.cxx @@ -147,16 +147,6 @@ using AllCellList = vtkm::ListTagJoin; using CoordinateList = vtkm::ListTagBase, vtkm::Vec>; -struct WaveletGeneratorDataFunctor -{ - template - bool operator()(DeviceAdapter, vtkm::worklet::WaveletGenerator& gen) - { - InputDataSet = gen.GenerateDataSet(); - return true; - } -}; - class BenchmarkFilterPolicy : public vtkm::filter::PolicyBase { public: @@ -1373,9 +1363,7 @@ int BenchmarkBody(int argc, char** argv, const vtkm::cont::InitializeResult& con vtkm::worklet::WaveletGenerator gen; gen.SetExtent({ 0 }, { waveletDim }); - // WaveletGenerator needs a template device argument not a id to deduce the portal type. - WaveletGeneratorDataFunctor genFunctor; - vtkm::cont::TryExecuteOnDevice(config.Device, genFunctor, gen); + InputDataSet = gen.GenerateDataSet(config.Device); } if (tetra) diff --git a/vtkm/worklet/WaveletGenerator.h b/vtkm/worklet/WaveletGenerator.h index 8e8b23de7..b11533d28 100644 --- a/vtkm/worklet/WaveletGenerator.h +++ b/vtkm/worklet/WaveletGenerator.h @@ -13,10 +13,10 @@ #include +#include #include #include #include -#include #include #include @@ -29,6 +29,95 @@ namespace vtkm namespace worklet { +namespace wavelet +{ +template +struct Worker : public vtkm::exec::FunctorBase +{ + using OutputHandleType = vtkm::cont::ArrayHandle; + using OutputPortalType = decltype(std::declval().PrepareForOutput(0, Device())); + using Vec3F = vtkm::Vec; + + Vec3F Center; + Vec3F Spacing; + Vec3F Frequency; + Vec3F Magnitude; + Vec3F MinimumPoint; + Vec3F Scale; + vtkm::Id3 Offset; + vtkm::Id3 Dims; + vtkm::FloatDefault MaximumValue; + vtkm::FloatDefault Temp2; + OutputPortalType Portal; + + VTKM_CONT + Worker(OutputHandleType& output, + const Vec3F& center, + const Vec3F& spacing, + const Vec3F& frequency, + const Vec3F& magnitude, + const Vec3F& minimumPoint, + const Vec3F& scale, + const vtkm::Id3& offset, + const vtkm::Id3& dims, + vtkm::FloatDefault maximumValue, + vtkm::FloatDefault temp2) + : Center(center) + , Spacing(spacing) + , Frequency(frequency) + , Magnitude(magnitude) + , MinimumPoint(minimumPoint) + , Scale(scale) + , Offset(offset) + , Dims(dims) + , MaximumValue(maximumValue) + , Temp2(temp2) + , Portal(output.PrepareForOutput((dims[0] * dims[1] * dims[2]), Device{})) + { + } + + VTKM_EXEC + void operator()(const vtkm::Id3& ijk) const + { + // map ijk to the point location, accounting for spacing: + const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing; + + // Compute the distance from the center of the gaussian: + const Vec3F scaledLoc = (this->Center - loc) * this->Scale; + vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc); + + const Vec3F periodicContribs{ + this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]), + this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]), + this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]), + }; + + // The vtkRTAnalyticSource documentation says the periodic contributions + // should be multiplied in, but the implementation adds them. We'll do as + // they do, not as they say. + const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) + + periodicContribs[0] + periodicContribs[1] + periodicContribs[2]; + + // Compute output location + // (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex) + const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]); + this->Portal.Set(scalarIdx, scalar); + } +}; + +struct runWorker +{ + template + inline bool operator()(Device, const vtkm::Id3 dims, Args... args) const + { + using Algo = vtkm::cont::DeviceAdapterAlgorithm; + Worker worker{ args... }; + Algo::Schedule(worker, dims); + return true; + } +}; +} + /** * @brief The WaveletGenerator class creates a dataset similar to VTK's * vtkRTAnalyticSource. @@ -80,81 +169,6 @@ class WaveletGenerator FloatDefault StandardDeviation; public: - template - struct Worker : public vtkm::exec::FunctorBase - { - using OutputHandleType = vtkm::cont::ArrayHandle; - using OutputPortalType = - decltype(std::declval().PrepareForOutput(0, Device())); - - Vec3F Center; - Vec3F Spacing; - Vec3F Frequency; - Vec3F Magnitude; - Vec3F MinimumPoint; - Vec3F Scale; - vtkm::Id3 Offset; - vtkm::Id3 Dims; - vtkm::FloatDefault MaximumValue; - vtkm::FloatDefault Temp2; - OutputHandleType Output; - OutputPortalType Portal; - - VTKM_CONT - Worker(const Vec3F& center, - const Vec3F& spacing, - const Vec3F& frequency, - const Vec3F& magnitude, - const Vec3F& minimumPoint, - const Vec3F& scale, - const vtkm::Id3& offset, - const vtkm::Id3& dims, - vtkm::FloatDefault maximumValue, - vtkm::FloatDefault temp2) - : Center(center) - , Spacing(spacing) - , Frequency(frequency) - , Magnitude(magnitude) - , MinimumPoint(minimumPoint) - , Scale(scale) - , Offset(offset) - , Dims(dims) - , MaximumValue(maximumValue) - , Temp2(temp2) - { - const vtkm::Id nVals = dims[0] * dims[1] * dims[2]; - this->Portal = this->Output.PrepareForOutput(nVals, Device()); - } - - VTKM_EXEC - void operator()(const vtkm::Id3& ijk) const - { - // map ijk to the point location, accounting for spacing: - const Vec3F loc = Vec3F(ijk + this->Offset) * this->Spacing; - - // Compute the distance from the center of the gaussian: - const Vec3F scaledLoc = (this->Center - loc) * this->Scale; - vtkm::FloatDefault gaussSum = vtkm::Dot(scaledLoc, scaledLoc); - - const Vec3F periodicContribs{ - this->Magnitude[0] * vtkm::Sin(this->Frequency[0] * scaledLoc[0]), - this->Magnitude[1] * vtkm::Sin(this->Frequency[1] * scaledLoc[1]), - this->Magnitude[2] * vtkm::Cos(this->Frequency[2] * scaledLoc[2]), - }; - - // The vtkRTAnalyticSource documentation says the periodic contributions - // should be multiplied in, but the implementation adds them. We'll do as - // they do, not as they say. - const vtkm::FloatDefault scalar = this->MaximumValue * vtkm::Exp(-gaussSum * this->Temp2) + - periodicContribs[0] + periodicContribs[1] + periodicContribs[2]; - - // Compute output location - // (see ConnectivityStructuredInternals<3>::LogicalToFlatPointIndex) - const vtkm::Id scalarIdx = ijk[0] + this->Dims[0] * (ijk[1] + this->Dims[1] * ijk[2]); - this->Portal.Set(scalarIdx, scalar); - } - }; - VTKM_CONT WaveletGenerator() : Center{ 0. } @@ -199,8 +213,8 @@ public: this->StandardDeviation = stdev; } - template - VTKM_CONT vtkm::cont::DataSet GenerateDataSet(Device = Device()) + VTKM_CONT vtkm::cont::DataSet GenerateDataSet( + vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const { // Create points: const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } }; @@ -212,7 +226,7 @@ public: cellSet.SetPointDimensions(dims); // Scalars, too - vtkm::cont::Field field = this->GenerateField("scalars"); + vtkm::cont::Field field = this->GenerateField("scalars", device); // Compile the dataset: vtkm::cont::DataSet dataSet; @@ -223,11 +237,10 @@ public: return dataSet; } - template - VTKM_CONT vtkm::cont::Field GenerateField(const std::string& name, Device = Device()) + VTKM_CONT vtkm::cont::Field GenerateField( + const std::string& name, + vtkm::cont::DeviceAdapterId device = vtkm::cont::DeviceAdapterTagAny()) const { - using Algo = vtkm::cont::DeviceAdapterAlgorithm; - const vtkm::Id3 dims{ this->MaximumExtent - this->MinimumExtent + vtkm::Id3{ 1 } }; Vec3F minPt = Vec3F(this->MinimumExtent) * this->Spacing; vtkm::FloatDefault temp2 = 1.f / (2.f * this->StandardDeviation * this->StandardDeviation); @@ -235,20 +248,23 @@ public: ComputeScaleFactor(this->MinimumExtent[1], this->MaximumExtent[1]), ComputeScaleFactor(this->MinimumExtent[2], this->MaximumExtent[2]) }; - Worker worker{ this->Center, - this->Spacing, - this->Frequency, - this->Magnitude, - minPt, - scale, - this->MinimumExtent, - dims, - this->MaximumValue, - temp2 }; - Algo::Schedule(worker, dims); - - return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, worker.Output); + vtkm::cont::ArrayHandle output; + vtkm::cont::TryExecuteOnDevice(device, + wavelet::runWorker{}, + dims, + output, + this->Center, + this->Spacing, + this->Frequency, + this->Magnitude, + minPt, + scale, + this->MinimumExtent, + dims, + this->MaximumValue, + temp2); + return vtkm::cont::Field(name, vtkm::cont::Field::Association::POINTS, output); } private: diff --git a/vtkm/worklet/testing/UnitTestWaveletGenerator.cxx b/vtkm/worklet/testing/UnitTestWaveletGenerator.cxx index caabb45f6..d30b6224d 100644 --- a/vtkm/worklet/testing/UnitTestWaveletGenerator.cxx +++ b/vtkm/worklet/testing/UnitTestWaveletGenerator.cxx @@ -13,28 +13,13 @@ #include #include -namespace detail -{ - -struct WaveletGeneratorFunctor -{ - vtkm::cont::DataSet Ds; - template - bool operator()(DeviceAdapter) - { - vtkm::worklet::WaveletGenerator gen; - Ds = gen.GenerateDataSet(); - return true; - } -}; -} - void WaveletGeneratorTest() { vtkm::cont::Timer timer; timer.Start(); - detail::WaveletGeneratorFunctor wgFunctor; - vtkm::cont::TryExecute(wgFunctor); + + vtkm::worklet::WaveletGenerator gen; + vtkm::cont::DataSet ds = gen.GenerateDataSet(); double time = timer.GetElapsedTime(); @@ -42,13 +27,13 @@ void WaveletGeneratorTest() std::cout << "Default wavelet took " << time << "s.\n"; { - auto coords = wgFunctor.Ds.GetCoordinateSystem("coords"); + auto coords = ds.GetCoordinateSystem("coords"); auto data = coords.GetData(); VTKM_TEST_ASSERT(test_equal(data.GetNumberOfValues(), 9261), "Incorrect number of points."); } { - auto cells = wgFunctor.Ds.GetCellSet(wgFunctor.Ds.GetCellSetIndex("cells")); + auto cells = ds.GetCellSet(ds.GetCellSetIndex("cells")); VTKM_TEST_ASSERT(test_equal(cells.GetNumberOfCells(), 8000), "Incorrect number of cells."); } @@ -56,7 +41,7 @@ void WaveletGeneratorTest() { using ScalarHandleType = vtkm::cont::ArrayHandle; - auto field = wgFunctor.Ds.GetField("scalars", vtkm::cont::Field::Association::POINTS); + auto field = ds.GetField("scalars", vtkm::cont::Field::Association::POINTS); auto dynData = field.GetData(); VTKM_TEST_ASSERT(dynData.IsType(), "Invalid scalar handle type."); ScalarHandleType handle = dynData.Cast();