DWT1D seems working OK now

This commit is contained in:
Samuel Li 2016-07-18 17:45:17 -06:00
parent 4b380a3d78
commit 6bfb1deb94
2 changed files with 48 additions and 20 deletions

@ -29,6 +29,8 @@
#include <vtkm/worklet/WaveletTransforms.h> #include <vtkm/worklet/WaveletTransforms.h>
#include <vtkm/cont/ArrayHandleConcatenate.h> #include <vtkm/cont/ArrayHandleConcatenate.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/ArrayHandlePermutation.h>
#include <vtkm/Math.h> #include <vtkm/Math.h>
@ -39,6 +41,7 @@ namespace internal {
class WaveletDWT : public WaveletBase class WaveletDWT : public WaveletBase
{ {
public: public:
typedef vtkm::cont::ArrayHandle< vtkm::Float64 > ArrayType64;
// Constructor // Constructor
WaveletDWT( const std::string &w_name ) : WaveletBase( w_name ) {} WaveletDWT( const std::string &w_name ) : WaveletBase( w_name ) {}
@ -123,9 +126,10 @@ public:
// Performs one level of 1D discrete wavelet transform // Performs one level of 1D discrete wavelet transform
// It takes care of boundary conditions, etc. // It takes care of boundary conditions, etc.
template< typename SignalArrayType, typename CoeffArrayType > template< typename SignalArrayType >
vtkm::Id DWT1D( const SignalArrayType &sigIn, // Input vtkm::Id DWT1D( const SignalArrayType &sigIn, // Input
CoeffArrayType &sigOut, ArrayType64 &cA, // Approximate Coefficients
ArrayType64 &cD, // Detail Coefficients
vtkm::Id L[3] ) vtkm::Id L[3] )
{ {
@ -169,15 +173,17 @@ public:
vtkm::Id sigExtendedLen = sigInLen + 2 * addLen; vtkm::Id sigExtendedLen = sigInLen + 2 * addLen;
typedef typename SignalArrayType::ValueType SigInValueType; typedef typename SignalArrayType::ValueType SigInValueType;
typedef vtkm::cont::ArrayHandle<SigInValueType> ArrayType; typedef vtkm::cont::ArrayHandle<SigInValueType> SigInArrayType;
typedef vtkm::cont::ArrayHandleConcatenate< ArrayType, ArrayType> typedef vtkm::cont::ArrayHandleConcatenate< SigInArrayType, SigInArrayType>
ArrayConcat; ArrayConcat;
typedef vtkm::cont::ArrayHandleConcatenate< ArrayConcat, ArrayType > ArrayConcat2; typedef vtkm::cont::ArrayHandleConcatenate< ArrayConcat, SigInArrayType > ArrayConcat2;
ArrayConcat2 sigInExtended; ArrayConcat2 sigInExtended;
this->Extend1D( sigIn, sigInExtended, addLen, this->wmode, this->wmode ); this->Extend1D( sigIn, sigInExtended, addLen, this->wmode, this->wmode );
ArrayType coeffOutTmp; // Coefficients in coeffOutTmp are interleaving,
// e.g. cA are at 0, 2, 4...; cD are at 1, 3, 5...
ArrayType64 coeffOutTmp;
// initialize a worklet // initialize a worklet
@ -186,10 +192,6 @@ public:
forwardTransform.SetCoeffLength( L[0], L[1] ); forwardTransform.SetCoeffLength( L[0], L[1] );
forwardTransform.SetOddness( oddLow, oddHigh ); forwardTransform.SetOddness( oddLow, oddHigh );
// setup a timer
//srand ((unsigned int)time(NULL));
//vtkm::cont::Timer<> timer;
vtkm::worklet::DispatcherMapField<vtkm::worklet::ForwardTransform> vtkm::worklet::DispatcherMapField<vtkm::worklet::ForwardTransform>
dispatcher(forwardTransform); dispatcher(forwardTransform);
dispatcher.Invoke( sigInExtended, dispatcher.Invoke( sigInExtended,
@ -197,20 +199,37 @@ public:
filter->GetHighDecomposeFilter(), filter->GetHighDecomposeFilter(),
coeffOutTmp ); coeffOutTmp );
//vtkm::Id randNum = rand() % sigLen; // Separate cA and cD.
//std::cout << "A random output: " typedef vtkm::cont::ArrayHandleCounting< vtkm::Id > IdArrayType;
// << outputArray1.GetPortalConstControl().Get(randNum) << std::endl; typedef vtkm::cont::ArrayHandlePermutation< IdArrayType, ArrayType64 > PermutArrayType;
//vtkm::Float64 elapsedTime = timer.GetElapsedTime(); IdArrayType approxIndices( 0, 2, L[0] );
//std::cerr << "Dealing array size " << sigLen/million << " millions takes time " IdArrayType detailIndices( 1, 2, L[1] );
// << elapsedTime << std::endl; PermutArrayType cATmp = vtkm::cont::make_ArrayHandlePermutation(
if( sigInLen < 21 ) approxIndices, coeffOutTmp );
PermutArrayType cDTmp = vtkm::cont::make_ArrayHandlePermutation(
detailIndices, coeffOutTmp );
vtkm::cont::DeviceAdapterAlgorithm< VTKM_DEFAULT_DEVICE_ADAPTER_TAG>::Copy(
cATmp, cA );
vtkm::cont::DeviceAdapterAlgorithm< VTKM_DEFAULT_DEVICE_ADAPTER_TAG>::Copy(
cDTmp, cD );
/*
if( sigInLen < 41 )
{
std::cout << "sigInExtended has length: " << sigInExtended.GetNumberOfValues() << std::endl;
std::cout << "coeffOutTmp has length: " << coeffOutTmp.GetNumberOfValues() << std::endl;
printf("L[3]: %lld, %lld, %lld\n", L[0], L[1], L[2]);
for (vtkm::Id i = 0; i < coeffOutTmp.GetNumberOfValues(); ++i) for (vtkm::Id i = 0; i < coeffOutTmp.GetNumberOfValues(); ++i)
{ {
std::cout << coeffOutTmp.GetPortalConstControl().Get(i) << ", "; std::cout << coeffOutTmp.GetPortalConstControl().Get(i) << ", ";
if( i % 2 != 0 ) if( i % 2 != 0 )
std::cout << std::endl; std::cout << std::endl;
} }
}
*/
return 0; return 0;
} }

@ -21,6 +21,8 @@
#include <vtkm/filter/internal/WaveletDWT.h> #include <vtkm/filter/internal/WaveletDWT.h>
#include <vtkm/cont/testing/Testing.h> #include <vtkm/cont/testing/Testing.h>
#include <vtkm/cont/ArrayHandlePermutation.h>
#include <vector> #include <vector>
@ -61,7 +63,7 @@ void TestDWT1D()
std::cout << "Input a new size to test (in millions)." << std::endl; std::cout << "Input a new size to test (in millions)." << std::endl;
std::cout << "Input 0 to stick with 20." << std::endl; std::cout << "Input 0 to stick with 20." << std::endl;
vtkm::Id tmpIn; vtkm::Id tmpIn;
vtkm::Id million = 1000000; vtkm::Id million = 1;//1000000;
std::cin >> tmpIn; std::cin >> tmpIn;
if( tmpIn != 0 ) if( tmpIn != 0 )
sigLen = tmpIn * million; sigLen = tmpIn * million;
@ -73,11 +75,18 @@ void TestDWT1D()
vtkm::cont::ArrayHandle<vtkm::Float64> inputArray = vtkm::cont::ArrayHandle<vtkm::Float64> inputArray =
vtkm::cont::make_ArrayHandle(tmpVector); vtkm::cont::make_ArrayHandle(tmpVector);
vtkm::cont::ArrayHandle<vtkm::Float64> outputArray; vtkm::cont::ArrayHandle<vtkm::Float64> cA, cD;
vtkm::Id L[3]; vtkm::Id L[3];
vtkm::filter::internal::WaveletDWT waveletdwt( "CDF9/7" ); vtkm::filter::internal::WaveletDWT waveletdwt( "CDF9/7" );
waveletdwt.DWT1D( inputArray, outputArray, L ); waveletdwt.DWT1D( inputArray, cA, cD, L );
std::cout << "cA: length=" << cA.GetNumberOfValues() << std::endl;
for( vtkm::Id i; i < cA.GetNumberOfValues(); i++ )
std::cout << cA.GetPortalConstControl().Get(i) << std::endl;
std::cout << "cD: length=" << cD.GetNumberOfValues() << std::endl;
for( vtkm::Id i; i < cD.GetNumberOfValues(); i++ )
std::cout << cD.GetPortalConstControl().Get(i) << std::endl;
} }