diff --git a/vtkm/worklet/Wavelets.h b/vtkm/worklet/Wavelets.h index a9537d404..609d32176 100644 --- a/vtkm/worklet/Wavelets.h +++ b/vtkm/worklet/Wavelets.h @@ -28,19 +28,50 @@ namespace vtkm { namespace worklet { +namespace internal{ + + const vtkm::Float64 hm4_44[9] = { + 0.037828455507264, + -0.023849465019557, + -0.110624404418437, + 0.377402855612831, + 0.852698679008894, + 0.377402855612831, + -0.110624404418437, + -0.023849465019557, + 0.037828455507264 + }; + + const vtkm::Float64 h4[9] = { + 0.0, + -0.064538882628697, + -0.040689417609164, + 0.418092273221617, + 0.788485616405583, + 0.418092273221617, + -0.0406894176091641, + -0.0645388826286971, + 0.0 + }; +} + class Wavelets { public: + // helper worklet class ForwardTransform: public vtkm::worklet::WorkletMapField { public: - typedef void ControlSignature(WholeArrayIn, // sigIn - FieldOut, // cA - FieldOut); // cD - typedef void ExecutionSignature(_1, _2, _3, WorkIndex); + typedef void ControlSignature(WholeArrayIn, // sigIn + WholeArrayIn, // lowFilter + WholeArrayIn, // highFilter + FieldOut); // cA in even indices, cD in odd indices + typedef void ExecutionSignature(_1, _2, _3, _4, WorkIndex); typedef _1 InputDomain; + typedef vtkm::Float64 FLOAT; + // ForwardTransform constructor VTKM_CONT_EXPORT ForwardTransform() @@ -51,16 +82,25 @@ public: } - template + template VTKM_EXEC_EXPORT - void operator()(const ArrayPortalType &signalIn, - T &coeffApproximation, - T &coeffDetail, + void operator()(const InputSignalPortalType &signalIn, + const FilterPortalType &lowFilter, + const FilterPortalType &highFilter, + OutputCoeffType &coeffOut, const vtkm::Id &workIndex) const { - vtkm::Float64 tmp = static_cast(signalIn.Get(workIndex)); - coeffApproximation = static_cast( tmp / 2.0 ); - coeffDetail = static_cast( tmp * 2.0 ); + FLOAT tmp = static_cast(signalIn.Get(workIndex)); + if( workIndex % 2 == 0 ) // work on cA, approximate coeffs + { + coeffOut = static_cast( tmpi + lowFilter.Get(0) ); + } + else // work on cD, detail coeffs + { + coeffOut = static_cast( tmp + highFilter.Get(0) ); + } } private: diff --git a/vtkm/worklet/testing/UnitTestWavelets.cxx b/vtkm/worklet/testing/UnitTestWavelets.cxx index daa129859..85e6614e3 100644 --- a/vtkm/worklet/testing/UnitTestWavelets.cxx +++ b/vtkm/worklet/testing/UnitTestWavelets.cxx @@ -33,26 +33,28 @@ void TestWavelets() vtkm::Id arraySize = 10; std::vector tmpVector; for( vtkm::Id i = 0; i < arraySize; i++ ) - tmpVector.push_back(static_cast(i*2)); + tmpVector.push_back(static_cast(i)); vtkm::cont::ArrayHandle input1DArray = vtkm::cont::make_ArrayHandle(tmpVector); vtkm::cont::ArrayHandle outputArray1; +/* outputArray1.Allocate( arraySize ); vtkm::cont::ArrayHandle outputArray2; outputArray2.Allocate( arraySize ); +*/ vtkm::worklet::Wavelets::ForwardTransform forwardTransform; vtkm::worklet::DispatcherMapField dispatcher(forwardTransform); - dispatcher.Invoke(input1DArray, outputArray1, outputArray2); + dispatcher.Invoke(input1DArray, outputArray1); std::cerr << "Invoke succeeded" << std::endl; for (vtkm::Id i = 0; i < outputArray1.GetNumberOfValues(); ++i) { - std::cout<< outputArray1.GetPortalConstControl().Get(i) << ", " - << outputArray2.GetPortalConstControl().Get(i) << std::endl; + std::cout<< outputArray1.GetPortalConstControl().Get(i) << std::endl; +// << outputArray2.GetPortalConstControl().Get(i) << std::endl; // VTKM_TEST_ASSERT( // test_equal( output1DArray.GetPortalConstControl().Get(i), // static_cast(i) * 2.0f ),