From 6bfb1deb944683aa41c001a73f76b4d0f265ffb7 Mon Sep 17 00:00:00 2001 From: Samuel Li Date: Mon, 18 Jul 2016 17:45:17 -0600 Subject: [PATCH] DWT1D seems working OK now --- vtkm/filter/internal/WaveletDWT.h | 53 +++++++++++++------ .../UnitTestWaveletCompressorFilter.cxx | 15 ++++-- 2 files changed, 48 insertions(+), 20 deletions(-) diff --git a/vtkm/filter/internal/WaveletDWT.h b/vtkm/filter/internal/WaveletDWT.h index 363bb42d4..fb01a313d 100644 --- a/vtkm/filter/internal/WaveletDWT.h +++ b/vtkm/filter/internal/WaveletDWT.h @@ -29,6 +29,8 @@ #include #include +#include +#include #include @@ -39,6 +41,7 @@ namespace internal { class WaveletDWT : public WaveletBase { public: + typedef vtkm::cont::ArrayHandle< vtkm::Float64 > ArrayType64; // Constructor WaveletDWT( const std::string &w_name ) : WaveletBase( w_name ) {} @@ -123,9 +126,10 @@ public: // Performs one level of 1D discrete wavelet transform // It takes care of boundary conditions, etc. - template< typename SignalArrayType, typename CoeffArrayType > + template< typename SignalArrayType > vtkm::Id DWT1D( const SignalArrayType &sigIn, // Input - CoeffArrayType &sigOut, + ArrayType64 &cA, // Approximate Coefficients + ArrayType64 &cD, // Detail Coefficients vtkm::Id L[3] ) { @@ -169,15 +173,17 @@ public: vtkm::Id sigExtendedLen = sigInLen + 2 * addLen; typedef typename SignalArrayType::ValueType SigInValueType; - typedef vtkm::cont::ArrayHandle ArrayType; - typedef vtkm::cont::ArrayHandleConcatenate< ArrayType, ArrayType> + typedef vtkm::cont::ArrayHandle SigInArrayType; + typedef vtkm::cont::ArrayHandleConcatenate< SigInArrayType, SigInArrayType> ArrayConcat; - typedef vtkm::cont::ArrayHandleConcatenate< ArrayConcat, ArrayType > ArrayConcat2; + typedef vtkm::cont::ArrayHandleConcatenate< ArrayConcat, SigInArrayType > ArrayConcat2; ArrayConcat2 sigInExtended; this->Extend1D( sigIn, sigInExtended, addLen, this->wmode, this->wmode ); - ArrayType coeffOutTmp; + // Coefficients in coeffOutTmp are interleaving, + // e.g. cA are at 0, 2, 4...; cD are at 1, 3, 5... + ArrayType64 coeffOutTmp; // initialize a worklet @@ -186,10 +192,6 @@ public: forwardTransform.SetCoeffLength( L[0], L[1] ); forwardTransform.SetOddness( oddLow, oddHigh ); - // setup a timer - //srand ((unsigned int)time(NULL)); - //vtkm::cont::Timer<> timer; - vtkm::worklet::DispatcherMapField dispatcher(forwardTransform); dispatcher.Invoke( sigInExtended, @@ -197,20 +199,37 @@ public: filter->GetHighDecomposeFilter(), coeffOutTmp ); - //vtkm::Id randNum = rand() % sigLen; - //std::cout << "A random output: " - // << outputArray1.GetPortalConstControl().Get(randNum) << std::endl; + // Separate cA and cD. + typedef vtkm::cont::ArrayHandleCounting< vtkm::Id > IdArrayType; + typedef vtkm::cont::ArrayHandlePermutation< IdArrayType, ArrayType64 > PermutArrayType; - //vtkm::Float64 elapsedTime = timer.GetElapsedTime(); - //std::cerr << "Dealing array size " << sigLen/million << " millions takes time " - // << elapsedTime << std::endl; - if( sigInLen < 21 ) + IdArrayType approxIndices( 0, 2, L[0] ); + IdArrayType detailIndices( 1, 2, L[1] ); + PermutArrayType cATmp = vtkm::cont::make_ArrayHandlePermutation( + approxIndices, coeffOutTmp ); + PermutArrayType cDTmp = vtkm::cont::make_ArrayHandlePermutation( + detailIndices, coeffOutTmp ); + + vtkm::cont::DeviceAdapterAlgorithm< VTKM_DEFAULT_DEVICE_ADAPTER_TAG>::Copy( + cATmp, cA ); + vtkm::cont::DeviceAdapterAlgorithm< VTKM_DEFAULT_DEVICE_ADAPTER_TAG>::Copy( + cDTmp, cD ); + + /* + if( sigInLen < 41 ) + { + std::cout << "sigInExtended has length: " << sigInExtended.GetNumberOfValues() << std::endl; + std::cout << "coeffOutTmp has length: " << coeffOutTmp.GetNumberOfValues() << std::endl; + printf("L[3]: %lld, %lld, %lld\n", L[0], L[1], L[2]); for (vtkm::Id i = 0; i < coeffOutTmp.GetNumberOfValues(); ++i) { std::cout << coeffOutTmp.GetPortalConstControl().Get(i) << ", "; if( i % 2 != 0 ) std::cout << std::endl; } + } + */ + return 0; } diff --git a/vtkm/filter/testing/UnitTestWaveletCompressorFilter.cxx b/vtkm/filter/testing/UnitTestWaveletCompressorFilter.cxx index 77fc04cd8..7518aed7c 100644 --- a/vtkm/filter/testing/UnitTestWaveletCompressorFilter.cxx +++ b/vtkm/filter/testing/UnitTestWaveletCompressorFilter.cxx @@ -21,6 +21,8 @@ #include #include +#include + #include @@ -61,7 +63,7 @@ void TestDWT1D() std::cout << "Input a new size to test (in millions)." << std::endl; std::cout << "Input 0 to stick with 20." << std::endl; vtkm::Id tmpIn; - vtkm::Id million = 1000000; + vtkm::Id million = 1;//1000000; std::cin >> tmpIn; if( tmpIn != 0 ) sigLen = tmpIn * million; @@ -73,11 +75,18 @@ void TestDWT1D() vtkm::cont::ArrayHandle inputArray = vtkm::cont::make_ArrayHandle(tmpVector); - vtkm::cont::ArrayHandle outputArray; + vtkm::cont::ArrayHandle cA, cD; vtkm::Id L[3]; vtkm::filter::internal::WaveletDWT waveletdwt( "CDF9/7" ); - waveletdwt.DWT1D( inputArray, outputArray, L ); + waveletdwt.DWT1D( inputArray, cA, cD, L ); + + std::cout << "cA: length=" << cA.GetNumberOfValues() << std::endl; + for( vtkm::Id i; i < cA.GetNumberOfValues(); i++ ) + std::cout << cA.GetPortalConstControl().Get(i) << std::endl; + std::cout << "cD: length=" << cD.GetNumberOfValues() << std::endl; + for( vtkm::Id i; i < cD.GetNumberOfValues(); i++ ) + std::cout << cD.GetPortalConstControl().Get(i) << std::endl; }