forward transform now avoids data copy when performing multi-level transforms

This commit is contained in:
Samuel Li 2016-10-05 15:52:07 -07:00
parent edd74febd8
commit 7eb494895f
4 changed files with 519 additions and 60 deletions

@ -195,27 +195,28 @@ public:
typedef typename OutArrayType::ValueType OutValueType;
typedef vtkm::cont::ArrayHandle<OutValueType> OutBasicArray;
//vtkm::cont::DeviceAdapterAlgorithm< DeviceTag >::Copy( sigIn, coeffOut );
// First level transform operates on the input array
computationTime +=
WaveletDWT::DWT2Dv3( sigIn, currentLenX, currentLenY, coeffOut, L2d, DeviceTag());
// First level transform operates writes to the output array
computationTime += WaveletDWT::DWT2Dv3( sigIn,
currentLenX, currentLenY,
0, 0,
currentLenX, currentLenY,
coeffOut, L2d, DeviceTag() );
VTKM_ASSERT( coeffOut.GetNumberOfValues() == currentLenX * currentLenY );
currentLenX = WaveletBase::GetApproxLength( currentLenX );
currentLenY = WaveletBase::GetApproxLength( currentLenY );
// Successor transforms operate on a temporary array
// Successor transforms writes to a temporary array
for( vtkm::Id i = nLevels-1; i > 0; i-- )
{
// make temporary input array
OutBasicArray tempInput;
WaveletBase::DeviceRectangleCopyFrom( tempInput, currentLenX, currentLenY,
coeffOut, inX, inY, 0, 0, DeviceTag() );
//make temporary output array
OutBasicArray tempOutput;
computationTime +=
WaveletDWT::DWT2Dv3( tempInput, currentLenX, currentLenY, tempOutput, L2d, DeviceTag());
WaveletDWT::DWT2Dv3( coeffOut,
inX, inY,
0, 0,
currentLenX, currentLenY,
tempOutput, L2d, DeviceTag() );
// copy results to coeffOut
WaveletBase::DeviceRectangleCopyTo( tempOutput, currentLenX, currentLenY,

@ -126,8 +126,8 @@ void FillArray2D( ArrayType& array, vtkm::Id dimX, vtkm::Id dimY )
void DebugDWTIDWT2D()
{
vtkm::Id NX = 4;
vtkm::Id NY = 4;
vtkm::Id NX = 10;
vtkm::Id NY = 11;
typedef vtkm::cont::ArrayHandle< vtkm::Float64 > ArrayType;
ArrayType left, center, right;
@ -138,13 +138,15 @@ void DebugDWTIDWT2D()
ArrayType output1, output2, output3;
std::vector<vtkm::Id> L(10, 0);
vtkm::worklet::wavelets::WaveletDWT dwt( vtkm::worklet::wavelets::HAAR );
vtkm::worklet::wavelets::WaveletDWT dwt( vtkm::worklet::wavelets::CDF9_7 );
// get true results
dwt.DWT2Dv2(center, NX, NY, output1, L, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
std::cerr << "...before DWT2Dv3..." << std::endl;
// get test results
dwt.DWT2Dv3( center, NX, NY, output3, L, VTKM_DEFAULT_DEVICE_ADAPTER_TAG() );
dwt.DWT2Dv3( center, NX, NY, 0, 0, NX, NY, output3, L, VTKM_DEFAULT_DEVICE_ADAPTER_TAG() );
std::cerr << "...finish DWT2Dv3..." << std::endl;
for( vtkm::Id i = 0; i < output1.GetNumberOfValues(); i++ )
{
@ -153,10 +155,10 @@ void DebugDWTIDWT2D()
"WaveletCompressor worklet failed..." );
}
dwt.Print2DArray("\ntrue results after 2D DWT:", output1, NX );
dwt.Print2DArray("\ntest results after 2D DWT:", output3, NX );
// dwt.Print2DArray("\ntrue results after 2D DWT:", output1, NX );
// dwt.Print2DArray("\ntest results after 2D DWT:", output3, NX );
ArrayType idwt_out1, idwt_out2;
ArrayType idwt_out1, idwt_out2;
// true results go through IDWT
dwt.IDWT2Dv2( output1, L, idwt_out1, VTKM_DEFAULT_DEVICE_ADAPTER_TAG() );
@ -231,7 +233,7 @@ void TestDecomposeReconstruct2D()
{
std::cout << "Testing a 1024x1024 square: " << std::endl;
vtkm::Id sigX = 1024;
vtkm::Id sigY = 1024;
vtkm::Id sigY = 1024;
//std::cout << "Please input X to test a X^2 square: " << std::endl;
//std::cin >> sigX;
//sigY = sigX;

@ -47,6 +47,172 @@ public:
typedef vtkm::Float64 FLOAT_64;
template< typename SigInArrayType, typename ExtensionArrayType, typename DeviceTag >
vtkm::Id Extend2Dv3(const SigInArrayType &sigIn, // Input
vtkm::Id sigDimX,
vtkm::Id sigDimY,
vtkm::Id sigStartX,
vtkm::Id sigStartY,
vtkm::Id sigPretendDimX,
vtkm::Id sigPretendDimY,
ExtensionArrayType &ext1, // left/top extension
ExtensionArrayType &ext2, // right/bottom extension
vtkm::Id addLen,
vtkm::worklet::wavelets::DWTMode ext1Method,
vtkm::worklet::wavelets::DWTMode ext2Method,
bool pretendSigPaddedZero,
bool padZeroAtExt2,
bool modeLR, // true = left-right
// false = top-down
DeviceTag )
{
// pretendSigPaddedZero and padZeroAtExt2 cannot happen at the same time
VTKM_ASSERT( !pretendSigPaddedZero || !padZeroAtExt2 );
if( addLen == 0 ) // Haar kernel
{
ext1.PrepareForOutput( 0, DeviceTag() );
if( pretendSigPaddedZero || padZeroAtExt2 )
{
if( modeLR ) // right extension
{
ext2.PrepareForOutput( sigPretendDimY, DeviceTag() );
WaveletBase::DeviceAssignZero2DColumn( ext2, 1, sigPretendDimY, 0, DeviceTag() );
}
else // bottom extension
{
ext2.PrepareForOutput( sigPretendDimX, DeviceTag() );
WaveletBase::DeviceAssignZero2DRow( ext2, sigPretendDimX, 1, 0, DeviceTag() );
}
}
else
ext2.PrepareForOutput( 0, DeviceTag() );
return 0;
}
typedef typename SigInArrayType::ValueType ValueType;
typedef vtkm::cont::ArrayHandle< ValueType > ExtendArrayType;
typedef vtkm::worklet::wavelets::ExtensionWorklet2Dv3 ExtensionWorklet;
typedef typename vtkm::worklet::DispatcherMapField< ExtensionWorklet, DeviceTag >
DispatcherType;
vtkm::Id extDimX, extDimY;
vtkm::worklet::wavelets::ExtensionDirection2D dir;
// Work on left/top extension
{
if( modeLR )
{
dir = LEFT;
extDimX = addLen;
extDimY = sigPretendDimY;
}
else
{
dir = TOP;
extDimX = sigPretendDimX;
extDimY = addLen;
}
ext1.PrepareForOutput( extDimX * extDimY, DeviceTag() );
ExtensionWorklet worklet( extDimX, extDimY, sigDimX, sigDimY,
sigStartX, sigStartY, sigPretendDimX, sigPretendDimY, // use this area
ext1Method, dir, false ); // not treating sigIn as having zeros
DispatcherType dispatcher( worklet );
dispatcher.Invoke( ext1, sigIn );
}
// Work on right/bottom extension
if( !pretendSigPaddedZero && !padZeroAtExt2 )
{
if( modeLR )
{
dir = RIGHT;
extDimX = addLen;
extDimY = sigPretendDimY;
}
else
{
dir = BOTTOM;
extDimX = sigPretendDimX;
extDimY = addLen;
}
ext2.PrepareForOutput( extDimX * extDimY, DeviceTag() );
ExtensionWorklet worklet( extDimX, extDimY, sigDimX, sigDimY,
sigStartX, sigStartY, sigPretendDimX, sigPretendDimY, // use this area
ext2Method, dir, false );
DispatcherType dispatcher( worklet );
dispatcher.Invoke( ext2, sigIn );
}
else if( !pretendSigPaddedZero && padZeroAtExt2 )
{
if( modeLR )
{
dir = RIGHT;
extDimX = addLen+1;
extDimY = sigPretendDimY;
}
else
{
dir = BOTTOM;
extDimX = sigPretendDimX;
extDimY = addLen+1;
}
ext2.PrepareForOutput( extDimX * extDimY, DeviceTag() );
ExtensionWorklet worklet( extDimX, extDimY, sigDimX, sigDimY,
sigStartX, sigStartY, sigPretendDimX, sigPretendDimY,
ext2Method, dir, false );
DispatcherType dispatcher( worklet );
dispatcher.Invoke( ext2, sigIn );
if( modeLR )
WaveletBase::DeviceAssignZero2DColumn( ext2, extDimX, extDimY,
extDimX-1, DeviceTag() );
else
WaveletBase::DeviceAssignZero2DRow( ext2, extDimX, extDimY,
extDimY-1, DeviceTag() );
}
else // pretendSigPaddedZero
{
ExtendArrayType ext2Temp;
if( modeLR )
{
dir = RIGHT;
extDimX = addLen;
extDimY = sigPretendDimY;
}
else
{
dir = BOTTOM;
extDimX = sigPretendDimX;
extDimY = addLen;
}
ext2Temp.PrepareForOutput( extDimX * extDimY, DeviceTag() );
ExtensionWorklet worklet( extDimX, extDimY, sigDimX, sigDimY,
sigStartX, sigStartY, sigPretendDimX, sigPretendDimY,
ext2Method, dir, true ); // pretend sig is padded a zero
DispatcherType dispatcher( worklet );
dispatcher.Invoke( ext2Temp, sigIn );
if( modeLR )
{
ext2.PrepareForOutput( (extDimX+1) * extDimY, DeviceTag() );
WaveletBase::DeviceRectangleCopyTo( ext2Temp, extDimX, extDimY,
ext2, extDimX+1, extDimY,
1, 0, DeviceTag() );
WaveletBase::DeviceAssignZero2DColumn( ext2, extDimX+1, extDimY,
0, DeviceTag() );
}
else
{
ext2.PrepareForOutput( extDimX * (extDimY+1), DeviceTag() );
WaveletBase::DeviceRectangleCopyTo( ext2Temp, extDimX, extDimY,
ext2, extDimX, extDimY+1,
0, 1, DeviceTag() );
WaveletBase::DeviceAssignZero2DRow( ext2, extDimX, extDimY+1,
0, DeviceTag() );
}
}
return 0;
}
#if 0
template< typename SigInArrayType, typename ExtensionArrayType, typename DeviceTag >
vtkm::Id Extend2Dv3(const SigInArrayType &sigIn, // Input
vtkm::Id sigDimX,
@ -203,7 +369,7 @@ public:
}
return 0;
}
#endif
// Func: Extend 1D signal
template< typename SigInArrayType, typename SigExtendedArrayType, typename DeviceTag >
@ -842,7 +1008,101 @@ if( print)
}
// Performs one level of 2D discrete wavelet transform
// Performs one level of 2D discrete wavelet transform on a small rectangle of input array
// The output has the same size as the small rectangle
template< typename ArrayInType, typename ArrayOutType, typename DeviceTag >
FLOAT_64 DWT2Dv3( const ArrayInType &sigIn,
vtkm::Id sigDimX,
vtkm::Id sigDimY,
vtkm::Id sigStartX,
vtkm::Id sigStartY,
vtkm::Id sigPretendDimX,
vtkm::Id sigPretendDimY,
ArrayOutType &coeffOut,
std::vector<vtkm::Id> &L,
DeviceTag )
{
VTKM_ASSERT( sigDimX * sigDimY == sigIn.GetNumberOfValues() );
VTKM_ASSERT( L.size() == 10 );
L[0] = WaveletBase::GetApproxLength( sigPretendDimX ); L[2] = L[0];
L[1] = WaveletBase::GetApproxLength( sigPretendDimY ); L[5] = L[1];
L[3] = WaveletBase::GetDetailLength( sigPretendDimY ); L[7] = L[3];
L[4] = WaveletBase::GetDetailLength( sigPretendDimX ); L[6] = L[4];
L[8] = sigPretendDimX;
L[9] = sigPretendDimY;
vtkm::Id filterLen = WaveletBase::filter.GetFilterLength();
bool oddLow = true;
if( filterLen % 2 != 0 )
oddLow = false;
vtkm::Id addLen = filterLen / 2;
typedef typename ArrayInType::ValueType ValueType;
typedef vtkm::cont::ArrayHandle<ValueType> ArrayType;
typedef vtkm::worklet::wavelets::ForwardTransform2Dv3<DeviceTag> ForwardXFormv3;
typedef vtkm::worklet::wavelets::ForwardTransform2D<DeviceTag> ForwardXForm;
typedef typename vtkm::worklet::DispatcherMapField< ForwardXFormv3, DeviceTag >
DispatcherType;
vtkm::cont::Timer<DeviceTag> timer;
vtkm::Float64 computationTime = 0.0;
ArrayType afterX;
afterX.PrepareForOutput( sigPretendDimX * sigPretendDimY, DeviceTag() );
// First transform on rows
{
ArrayType leftExt, rightExt;
this->Extend2Dv3( sigIn,
sigDimX, sigDimY,
sigStartX, sigStartY,
sigPretendDimX, sigPretendDimY,
leftExt, rightExt, addLen,
WaveletBase::wmode, WaveletBase::wmode, false, false,
true, DeviceTag() ); // Extend in left-right direction
timer.Reset();
ForwardXFormv3 worklet( WaveletBase::filter.GetLowDecomposeFilter(),
WaveletBase::filter.GetHighDecomposeFilter(),
filterLen, L[0], oddLow, true, // left-right
addLen, sigPretendDimY,
sigDimX, sigDimY,
sigStartX, sigStartY,
sigPretendDimX, sigPretendDimY,
addLen, sigPretendDimY );
DispatcherType dispatcher(worklet);
dispatcher.Invoke( leftExt, sigIn, rightExt, afterX );
computationTime += timer.GetElapsedTime();
}
// Then do transform in Y direction
{
ArrayType topExt, bottomExt;
coeffOut.PrepareForOutput( sigPretendDimX * sigPretendDimY, DeviceTag() );
this->Extend2Dv3( afterX,
sigPretendDimX, sigPretendDimY,
0, 0,
sigPretendDimX, sigPretendDimY,
topExt, bottomExt, addLen,
WaveletBase::wmode, WaveletBase::wmode, false, false,
false, DeviceTag() ); // Extend in top-down direction
timer.Reset();
ForwardXFormv3 worklet( WaveletBase::filter.GetLowDecomposeFilter(),
WaveletBase::filter.GetHighDecomposeFilter(),
filterLen, L[1], oddLow, false, // top-down
sigPretendDimX, addLen,
sigPretendDimX, sigPretendDimY,
0, 0,
sigPretendDimX, sigPretendDimY,
sigPretendDimX, addLen );
DispatcherType dispatcher( worklet );
dispatcher.Invoke( topExt, afterX, bottomExt, coeffOut );
computationTime += timer.GetElapsedTime();
}
return computationTime;
}
#if 0
template< typename ArrayInType, typename ArrayOutType, typename DeviceTag >
FLOAT_64 DWT2Dv3( const ArrayInType &sigIn,
vtkm::Id sigDimX,
@ -917,6 +1177,7 @@ if( print)
}
return computationTime;
}
#endif
template< typename ArrayInType, typename ArrayOutType, typename DeviceTag >
@ -1057,7 +1318,7 @@ if( print)
coeffIn, inDimX, inDimY,
0, 0, DeviceTag() );
// extend cA
this->Extend2Dv3( cA, cADimX, cADimY, ext1, ext2, addLen,
this->Extend2Dv3( cA, cADimX, cADimY, 0, 0, cADimX, cADimY, ext1, ext2, addLen,
cALeft, cARight, false, false, true, DeviceTag() );
cA.ReleaseResources();
ext1DimX = ext2DimX = addLen;
@ -1072,7 +1333,7 @@ if( print)
// extend cD
if( cDPadLen > 0 )
{
this->Extend2Dv3( cD, cDDimX, cDDimY, ext3, ext4, addLen,
this->Extend2Dv3( cD, cDDimX, cDDimY, 0, 0, cDDimX, cDDimY, ext3, ext4, addLen,
cDLeft, cDRight, true, false, true, DeviceTag() );
ext3DimX = addLen;
ext4DimX = addLen + 1;
@ -1082,13 +1343,13 @@ if( print)
vtkm::Id cDExtendedWouldBe = cDDimX + 2 * addLen;
if( cDExtendedWouldBe == cDExtendedDimX )
{
this->Extend2Dv3( cD, cDDimX, cDDimY, ext3, ext4, addLen,
this->Extend2Dv3( cD, cDDimX, cDDimY, 0, 0, cDDimX, cDDimY, ext3, ext4, addLen,
cDLeft, cDRight, false, false, true, DeviceTag());
ext3DimX = ext4DimX = addLen;
}
else if( cDExtendedWouldBe == cDExtendedDimX - 1 )
{
this->Extend2Dv3( cD, cDDimX, cDDimY, ext3, ext4, addLen,
this->Extend2Dv3( cD, cDDimX, cDDimY, 0, 0, cDDimX, cDDimY, ext3, ext4, addLen,
cDLeft, cDRight, false, true, true, DeviceTag());
ext3DimX = addLen;
ext4DimX = addLen + 1;
@ -1164,7 +1425,7 @@ if( print)
coeffIn, inDimX, inDimY,
0, 0, DeviceTag() );
// extend cA
this->Extend2Dv3( cA, cADimX, cADimY, ext1, ext2, addLen,
this->Extend2Dv3( cA, cADimX, cADimY, 0, 0, cADimX, cADimY, ext1, ext2, addLen,
cATop, cABottom, false, false, false, DeviceTag() );
cA.ReleaseResources();
ext1DimY = ext2DimY = addLen;
@ -1179,7 +1440,7 @@ if( print)
// extend cD
if( cDPadLen > 0 )
{
this->Extend2Dv3( cD, cDDimX, cDDimY, ext3, ext4, addLen,
this->Extend2Dv3( cD, cDDimX, cDDimY, 0, 0, cDDimX, cDDimY, ext3, ext4, addLen,
cDTop, cDBottom, true, false, false, DeviceTag() );
ext3DimY = addLen;
ext4DimY = addLen + 1;
@ -1189,13 +1450,13 @@ if( print)
vtkm::Id cDExtendedWouldBe = cDDimY + 2 * addLen;
if( cDExtendedWouldBe == cDExtendedDimY )
{
this->Extend2Dv3( cD, cDDimX, cDDimY, ext3, ext4, addLen,
this->Extend2Dv3( cD, cDDimX, cDDimY, 0, 0, cDDimX, cDDimY, ext3, ext4, addLen,
cDTop, cDBottom, false, false, false, DeviceTag());
ext3DimY = ext4DimY = addLen;
}
else if( cDExtendedWouldBe == cDExtendedDimY - 1 )
{
this->Extend2Dv3( cD, cDDimX, cDDimY, ext3, ext4, addLen,
this->Extend2Dv3( cD, cDDimX, cDDimY, 0, 0, cDDimX, cDDimY, ext3, ext4, addLen,
cDTop, cDBottom, false, true, false, DeviceTag());
ext3DimY = addLen;
ext4DimY = addLen + 1;

@ -170,15 +170,38 @@ private:
};
// ................
// . .
// -----.--------------.-----
// | . | | . |
// | . | | . |
// | ext1 | mat2 | ext2 |
// | (x1) | (x2) | (x3) |
// | . | | . |
// -----.--------------.-----
// ................
class IndexTranslator3Matrices
{
public:
IndexTranslator3Matrices( vtkm::Id x_1, vtkm::Id y_1,
vtkm::Id x_2, vtkm::Id y_2,
IndexTranslator3Matrices( vtkm::Id x_1, vtkm::Id y_1,
vtkm::Id x_2, vtkm::Id y_2,
vtkm::Id startx_2, vtkm::Id starty_2,
vtkm::Id pretendx_2, vtkm::Id pretendy_2,
vtkm::Id x_3, vtkm::Id y_3, bool mode )
: x1(x_1), y1(y_1),
x2(x_2), y2(y_2),
x3(x_3), y3(y_3), mode_lr(mode) {}
:
dimX1(x_1), dimY1(y_1),
dimX2(x_2), dimY2(y_2), // real dimension of 2nd matrix
startX2( startx_2 ), startY2( starty_2 ),
pretendDimX2( pretendx_2 ), pretendDimY2( pretendy_2 ),
dimX3(x_3), dimY3(y_3),
mode_lr(mode)
{
/* printf("ext1 dims : %lld, %lld\n", dimX1, dimY1);
printf("signal dims : %lld, %lld\n", dimX2, dimY2);
printf("signal start : %lld, %lld\n", startX2, startY2);
printf("signal pretend dimx : %lld, %lld\n", pretendDimX2, pretendDimY2);
printf("ext2 dims : %lld, %lld\n", dimX3, dimY3); */
}
VTKM_EXEC_CONT_EXPORT
void Translate2Dto1D( vtkm::Id inX, vtkm::Id inY, // 2D indices as input
@ -186,40 +209,40 @@ public:
{
if( mode_lr ) // left-right mode
{
if ( 0 <= inX && inX < x1 )
if ( 0 <= inX && inX < dimX1 )
{
mat = 1;
idx = inY * x1 + inX;
idx = inY * dimX1 + inX;
}
else if ( x1 <= inX && inX < (x1 + x2) )
else if ( dimX1 <= inX && inX < (dimX1 + pretendDimX2) )
{
mat = 2;
idx = inY * x2 + (inX - x1);
idx = (inY + startY2) * dimX2 + (inX + startX2 - dimX1);
}
else if ( (x1 + x2) <= inX && inX < (x1 + x2 + x3) )
else if ( (dimX1 + pretendDimX2) <= inX && inX < (dimX1 + pretendDimX2 + dimX3) )
{
mat = 3;
idx = inY * x3 + (inX - x1 - x2);
idx = inY * dimX3 + (inX - dimX1 - pretendDimX2);
}
else
vtkm::cont::ErrorControlInternal("Invalid index!");
}
else // top-down mode
{
if ( 0 <= inY && inY < y1 )
if ( 0 <= inY && inY < dimY1 )
{
mat = 1;
idx = inY * x1 + inX;
idx = inY * dimX1 + inX;
}
else if ( y1 <= inY && inY < (y1 + y2) )
else if ( dimY1 <= inY && inY < (dimY1 + pretendDimY2) )
{
mat = 2;
idx = (inY - y1) * x1 + inX;
idx = (inY + startY2 - dimY1) * dimX2 + inX + startX2;
}
else if ( (y1 + y2) <= inY && inY < (y1 + y2 + y3) )
else if ( (dimY1 + pretendDimY2) <= inY && inY < (dimY1 + pretendDimY2 + dimY3) )
{
mat = 3;
idx = (inY - y1 - y2) * x1 + inX;
idx = (inY - dimY1 - pretendDimY2) * dimX3 + inX;
}
else
vtkm::cont::ErrorControlInternal("Invalid index!");
@ -227,7 +250,9 @@ public:
}
private:
const vtkm::Id x1, y1, x2, y2, x3, y3;
const vtkm::Id dimX1, dimY1;
const vtkm::Id dimX2, dimY2, startX2, startY2, pretendDimX2, pretendDimY2;
const vtkm::Id dimX3, dimY3;
const bool mode_lr; // true: left right mode; false: top down mode.
};
@ -245,6 +270,176 @@ public:
typedef _4 InputDomain;
// Constructor
VTKM_EXEC_CONT_EXPORT
ForwardTransform2Dv3( const vtkm::cont::ArrayHandle<vtkm::Float64> &loFilter,
const vtkm::cont::ArrayHandle<vtkm::Float64> &hiFilter,
vtkm::Id filter_len, vtkm::Id approx_len,
bool odd_low, bool mode_lr,
vtkm::Id x1, vtkm::Id y1, // dims of left/top extension
vtkm::Id x2, vtkm::Id y2, // dims of signal
vtkm::Id startx2, vtkm::Id starty2, // start idx of signal
vtkm::Id pretendx2, vtkm::Id pretendy2, // pretend dims of signal
vtkm::Id x3, vtkm::Id y3 ) // dims of right/bottom extension
:
lowFilter( loFilter.PrepareForInput( DeviceTag() ) ),
highFilter( hiFilter.PrepareForInput( DeviceTag() ) ),
filterLen( filter_len ), approxLen( approx_len ),
outDimX( pretendx2 ), outDimY( pretendy2 ),
oddlow( odd_low ), modeLR( mode_lr ),
translator( x1, y1,
x2, y2,
startx2, starty2,
pretendx2, pretendy2,
x3, y3,
mode_lr )
{ this->SetStartPosition(); }
VTKM_EXEC_CONT_EXPORT
void Output1Dto2D( vtkm::Id idx, vtkm::Id &x, vtkm::Id &y ) const
{
x = idx % outDimX;
y = idx / outDimX;
}
VTKM_EXEC_CONT_EXPORT
vtkm::Id Output2Dto1D( vtkm::Id x, vtkm::Id y ) const
{
return y * outDimX + x;
}
// Use 64-bit float for convolution calculation
#define VAL vtkm::Float64
#define MAKEVAL(a) (static_cast<VAL>(a))
template <typename InPortalType1, typename InPortalType2, typename InPortalType3 >
VTKM_EXEC_CONT_EXPORT
VAL GetVal( const InPortalType1 &portal1, const InPortalType2 &portal2,
const InPortalType3 &portal3, vtkm::Id inMatrix, vtkm::Id inIdx ) const
{
if( inMatrix == 1 )
return MAKEVAL( portal1.Get(inIdx) );
else if( inMatrix == 2 )
return MAKEVAL( portal2.Get(inIdx) );
else if( inMatrix == 3 )
return MAKEVAL( portal3.Get(inIdx) );
else
{
vtkm::cont::ErrorControlInternal("Invalid matrix index!");
return -1;
}
}
template <typename InPortalType1, typename InPortalType2,
typename InPortalType3, typename OutputPortalType>
VTKM_EXEC_CONT_EXPORT
void operator()(const InPortalType1 &inPortal1, // left/top extension
const InPortalType2 &inPortal2, // signal
const InPortalType3 &inPortal3, // right/bottom extension
OutputPortalType &coeffOut,
const vtkm::Id &workIndex) const
{
vtkm::Id workX, workY, output1D;
Output1Dto2D( workIndex, workX, workY );
vtkm::Id inputMatrix, inputIdx;
typedef typename OutputPortalType::ValueType OutputValueType;
if( modeLR )
{
if( workX % 2 == 0 ) // calculate cA
{
vtkm::Id xl = lstart + workX;
VAL sum = MAKEVAL(0.0);
for( vtkm::Id k = filterLen - 1; k > -1; k-- )
{
translator.Translate2Dto1D( xl, workY, inputMatrix, inputIdx );
sum += lowFilter.Get(k) *
GetVal( inPortal1, inPortal2, inPortal3, inputMatrix, inputIdx );
xl++;
}
output1D = Output2Dto1D( workX/2, workY );
coeffOut.Set( output1D, static_cast<OutputValueType>(sum) );
}
else // calculate cD
{
vtkm::Id xh = hstart + workX - 1;
VAL sum=MAKEVAL(0.0);
for( vtkm::Id k = filterLen - 1; k > -1; k-- )
{
translator.Translate2Dto1D( xh, workY, inputMatrix, inputIdx );
sum += highFilter.Get(k) *
GetVal( inPortal1, inPortal2, inPortal3, inputMatrix, inputIdx );
xh++;
}
output1D = Output2Dto1D( (workX-1)/2 + approxLen, workY );
coeffOut.Set( output1D, static_cast<OutputValueType>(sum) );
}
}
else // top-down order
{
if( workY % 2 == 0 ) // calculate cA
{
vtkm::Id yl = lstart + workY;
VAL sum = MAKEVAL(0.0);
for( vtkm::Id k = filterLen - 1; k > -1; k-- )
{
translator.Translate2Dto1D( workX, yl, inputMatrix, inputIdx );
sum += lowFilter.Get(k) *
GetVal( inPortal1, inPortal2, inPortal3, inputMatrix, inputIdx );
yl++;
}
output1D = Output2Dto1D( workX, workY/2 );
coeffOut.Set( output1D, static_cast<OutputValueType>(sum) );
}
else // calculate cD
{
vtkm::Id yh = hstart + workY - 1;
VAL sum=MAKEVAL(0.0);
for( vtkm::Id k = filterLen - 1; k > -1; k-- )
{
translator.Translate2Dto1D( workX, yh, inputMatrix, inputIdx );
sum += highFilter.Get(k) *
GetVal( inPortal1, inPortal2, inPortal3, inputMatrix, inputIdx );
yh++;
}
output1D = Output2Dto1D( workX, (workY-1)/2 + approxLen );
coeffOut.Set( output1D, static_cast<OutputValueType>(sum) );
}
}
}
#undef MAKEVAL
#undef VAL
private:
const typename vtkm::cont::ArrayHandle<vtkm::Float64>::ExecutionTypes<DeviceTag>::
PortalConst lowFilter, highFilter;
const vtkm::Id filterLen, approxLen;
const vtkm::Id outDimX, outDimY;
bool oddlow;
bool modeLR; // true = left right; false = top down.
const IndexTranslator3Matrices translator;
vtkm::Id lstart, hstart;
VTKM_EXEC_CONT_EXPORT
void SetStartPosition()
{
this->lstart = this->oddlow ? 1 : 0;
this->hstart = 1;
}
};
#if 0
template< typename DeviceTag >
class ForwardTransform2Dv3: public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature(WholeArrayIn<ScalarAll>, // left/top extension
WholeArrayIn<ScalarAll>, // sigIn
WholeArrayIn<ScalarAll>, // right/bottom extension
WholeArrayOut<ScalarAll>); // cA followed by cD
typedef void ExecutionSignature(_1, _2, _3, _4, WorkIndex);
typedef _4 InputDomain;
// Constructor
VTKM_EXEC_CONT_EXPORT
ForwardTransform2Dv3( const vtkm::cont::ArrayHandle<vtkm::Float64> &loFilter,
@ -393,11 +588,11 @@ private:
this->hstart = 1;
}
};
#endif
// Worklet for 2D signal extension
// This implementation operates on a small rectangle. Use it after LANL.
#if 0
// This implementation operates on a small rectangle.
class ExtensionWorklet2Dv3 : public vtkm::worklet::WorkletMapField
{
public:
@ -408,15 +603,16 @@ public:
// Constructor
VTKM_EXEC_CONT_EXPORT
ExtensionWorklet2Dv3( vtkm::Id x1, vtkm::Id y1,
vtkm::Id x2, vtkm::Id y2,
vtkm::Id x3, vtkm::Id y3,
vtkm::Id x4, vtkm::Id y4,
ExtensionWorklet2Dv3( vtkm::Id extdimX, vtkm::Id extdimY,
vtkm::Id sigdimX, vtkm::Id sigdimY,
vtkm::Id sigstartX, vtkm::Id sigstartY,
vtkm::Id sigpretendX, vtkm::Id sigpretendY,
DWTMode m, ExtensionDirection2D dir, bool pad_zero)
: extDimX( x1 ), extDimY( y1 ),
sigDimX( x2 ), sigDimY( y2 ),
sigPretendDimX( x3 ), sigPretendDimY( y3 ),
sigStartDimX( x4 ), sigStartDimY( y4 ),
:
extDimX( extdimX ), extDimY( extdimY ),
sigDimX( sigdimX ), sigDimY( sigdimY ),
sigStartX( sigstartX ), sigStartY( sigstartY ),
sigPretendDimX( sigpretendX ), sigPretendDimY( sigpretendY ),
mode(m), direction( dir ), padZero( pad_zero ) {}
// Index translation helper
@ -502,11 +698,10 @@ private:
const ExtensionDirection2D direction;
const bool padZero; // treat sigIn as having a column/row zeros
};
#endif
// !!! useful code above !!!
// Worklet for 2D signal extension
#if 0
class ExtensionWorklet2D : public vtkm::worklet::WorkletMapField
{
public:
@ -597,7 +792,7 @@ private:
const ExtensionDirection2D direction;
const bool padZero; // treat sigIn as having a column/row zeros
};
#endif
// Worklet: perform a simple forward transform
template< typename DeviceTag >