reorganize class orders

This commit is contained in:
Samuel Li 2016-12-06 17:47:38 -07:00
parent b76280c849
commit 6db60017d8

@ -260,6 +260,117 @@ private:
// Worklet for 2D signal extension
// This implementation operates on a specified part of a big rectangle
class ExtensionWorklet2D : public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature( WholeArrayOut < ScalarAll >, // extension part
WholeArrayIn < ScalarAll > ); // signal part
typedef void ExecutionSignature( _1, _2, WorkIndex );
typedef _1 InputDomain;
// Constructor
VTKM_EXEC_CONT
ExtensionWorklet2D ( vtkm::Id extdimX, vtkm::Id extdimY,
vtkm::Id sigdimX, vtkm::Id sigdimY,
vtkm::Id sigstartX, vtkm::Id sigstartY,
vtkm::Id sigpretendX, vtkm::Id sigpretendY,
DWTMode m, ExtensionDirection2D dir, bool pad_zero)
:
extDimX( extdimX ), extDimY( extdimY ),
sigDimX( sigdimX ), sigDimY( sigdimY ),
sigStartX( sigstartX ), sigStartY( sigstartY ),
sigPretendDimX( sigpretendX ), sigPretendDimY( sigpretendY ),
mode(m), direction( dir ), padZero( pad_zero )
{ (void)sigDimY; }
// Index translation helper
VTKM_EXEC_CONT
void Ext1Dto2D ( vtkm::Id idx, vtkm::Id &x, vtkm::Id &y ) const
{
x = idx % extDimX;
y = idx / extDimX;
}
// Index translation helper
VTKM_EXEC_CONT
vtkm::Id Sig2Dto1D( vtkm::Id x, vtkm::Id y ) const
{
return y * sigDimX + x;
}
// Index translation helper
VTKM_EXEC_CONT
vtkm::Id SigPretend2Dto1D( vtkm::Id x, vtkm::Id y ) const
{
return (y + sigStartY) * sigDimX + x + sigStartX;
}
template< typename PortalOutType, typename PortalInType >
VTKM_EXEC
void operator()( PortalOutType &portalOut,
const PortalInType &portalIn,
const vtkm::Id &workIndex) const
{
vtkm::Id extX, extY, sigPretendX, sigPretendY;
Ext1Dto2D( workIndex, extX, extY );
typename PortalOutType::ValueType sym = 1.0;
if( mode == ASYMH || mode == ASYMW )
sym = -1.0;
if( direction == LEFT )
{
sigPretendY = extY;
if( mode == SYMH || mode == ASYMH )
sigPretendX = extDimX - extX - 1;
else // mode == SYMW || mode == ASYMW
sigPretendX = extDimX - extX;
}
else if( direction == TOP )
{
sigPretendX = extX;
if( mode == SYMH || mode == ASYMH )
sigPretendY = extDimY - extY - 1;
else // mode == SYMW || mode == ASYMW
sigPretendY = extDimY - extY;
}
else if( direction == RIGHT )
{
sigPretendY = extY;
if( mode == SYMH || mode == ASYMH )
sigPretendX = sigPretendDimX - extX - 1;
else
sigPretendX = sigPretendDimX - extX - 2;
if( padZero )
sigPretendX++;
}
else // direction == BOTTOM
{
sigPretendX = extX;
if( mode == SYMH || mode == ASYMH )
sigPretendY = sigPretendDimY - extY - 1;
else
sigPretendY = sigPretendDimY - extY - 2;
if( padZero )
sigPretendY++;
}
if( sigPretendX == sigPretendDimX || sigPretendY == sigPretendDimY )
portalOut.Set( workIndex, 0.0 );
else
portalOut.Set( workIndex, sym *
portalIn.Get( SigPretend2Dto1D(sigPretendX, sigPretendY) ));
}
private:
const vtkm::Id extDimX, extDimY, sigDimX, sigDimY;
const vtkm::Id sigStartX, sigStartY, sigPretendDimX, sigPretendDimY;
const DWTMode mode;
const ExtensionDirection2D direction;
const bool padZero; // treat sigIn as having a column/row zeros
};
// Worklet: perform a simple 2D forward transform
template< typename DeviceTag >
class ForwardTransform2D: public vtkm::worklet::WorkletMapField
@ -431,194 +542,6 @@ private:
// Worklet for 2D signal extension
// This implementation operates on a specified part of a big rectangle
class ExtensionWorklet2D : public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature( WholeArrayOut < ScalarAll >, // extension part
WholeArrayIn < ScalarAll > ); // signal part
typedef void ExecutionSignature( _1, _2, WorkIndex );
typedef _1 InputDomain;
// Constructor
VTKM_EXEC_CONT
ExtensionWorklet2D ( vtkm::Id extdimX, vtkm::Id extdimY,
vtkm::Id sigdimX, vtkm::Id sigdimY,
vtkm::Id sigstartX, vtkm::Id sigstartY,
vtkm::Id sigpretendX, vtkm::Id sigpretendY,
DWTMode m, ExtensionDirection2D dir, bool pad_zero)
:
extDimX( extdimX ), extDimY( extdimY ),
sigDimX( sigdimX ), sigDimY( sigdimY ),
sigStartX( sigstartX ), sigStartY( sigstartY ),
sigPretendDimX( sigpretendX ), sigPretendDimY( sigpretendY ),
mode(m), direction( dir ), padZero( pad_zero )
{ (void)sigDimY; }
// Index translation helper
VTKM_EXEC_CONT
void Ext1Dto2D ( vtkm::Id idx, vtkm::Id &x, vtkm::Id &y ) const
{
x = idx % extDimX;
y = idx / extDimX;
}
// Index translation helper
VTKM_EXEC_CONT
vtkm::Id Sig2Dto1D( vtkm::Id x, vtkm::Id y ) const
{
return y * sigDimX + x;
}
// Index translation helper
VTKM_EXEC_CONT
vtkm::Id SigPretend2Dto1D( vtkm::Id x, vtkm::Id y ) const
{
return (y + sigStartY) * sigDimX + x + sigStartX;
}
template< typename PortalOutType, typename PortalInType >
VTKM_EXEC
void operator()( PortalOutType &portalOut,
const PortalInType &portalIn,
const vtkm::Id &workIndex) const
{
vtkm::Id extX, extY, sigPretendX, sigPretendY;
Ext1Dto2D( workIndex, extX, extY );
typename PortalOutType::ValueType sym = 1.0;
if( mode == ASYMH || mode == ASYMW )
sym = -1.0;
if( direction == LEFT )
{
sigPretendY = extY;
if( mode == SYMH || mode == ASYMH )
sigPretendX = extDimX - extX - 1;
else // mode == SYMW || mode == ASYMW
sigPretendX = extDimX - extX;
}
else if( direction == TOP )
{
sigPretendX = extX;
if( mode == SYMH || mode == ASYMH )
sigPretendY = extDimY - extY - 1;
else // mode == SYMW || mode == ASYMW
sigPretendY = extDimY - extY;
}
else if( direction == RIGHT )
{
sigPretendY = extY;
if( mode == SYMH || mode == ASYMH )
sigPretendX = sigPretendDimX - extX - 1;
else
sigPretendX = sigPretendDimX - extX - 2;
if( padZero )
sigPretendX++;
}
else // direction == BOTTOM
{
sigPretendX = extX;
if( mode == SYMH || mode == ASYMH )
sigPretendY = sigPretendDimY - extY - 1;
else
sigPretendY = sigPretendDimY - extY - 2;
if( padZero )
sigPretendY++;
}
if( sigPretendX == sigPretendDimX || sigPretendY == sigPretendDimY )
portalOut.Set( workIndex, 0.0 );
else
portalOut.Set( workIndex, sym *
portalIn.Get( SigPretend2Dto1D(sigPretendX, sigPretendY) ));
}
private:
const vtkm::Id extDimX, extDimY, sigDimX, sigDimY;
const vtkm::Id sigStartX, sigStartY, sigPretendDimX, sigPretendDimY;
const DWTMode mode;
const ExtensionDirection2D direction;
const bool padZero; // treat sigIn as having a column/row zeros
};
// Worklet: perform a simple 1D forward transform
template< typename DeviceTag >
class ForwardTransform: public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature(WholeArrayIn<ScalarAll>, // sigIn
WholeArrayOut<ScalarAll>); // cA followed by cD
typedef void ExecutionSignature(_1, _2, WorkIndex);
typedef _1 InputDomain;
// Constructor
VTKM_EXEC_CONT
ForwardTransform( const vtkm::cont::ArrayHandle<vtkm::Float64> &loFilter,
const vtkm::cont::ArrayHandle<vtkm::Float64> &hiFilter,
vtkm::Id filLen, vtkm::Id approx_len, vtkm::Id detail_len,
bool odd_low, bool odd_high ) :
lowFilter( loFilter.PrepareForInput(DeviceTag()) ),
highFilter( hiFilter.PrepareForInput(DeviceTag()) ),
filterLen( filLen ),
approxLen( approx_len ),
detailLen( detail_len ),
oddlow ( odd_low ),
oddhigh ( odd_high )
{ this->SetStartPosition(); }
// Use 64-bit float for convolution calculation
#define VAL vtkm::Float64
#define MAKEVAL(a) (static_cast<VAL>(a))
template <typename InputPortalType, typename OutputPortalType>
VTKM_EXEC
void operator()(const InputPortalType &signalIn,
OutputPortalType &coeffOut,
const vtkm::Id &workIndex) const
{
typedef typename OutputPortalType::ValueType OutputValueType;
if( workIndex < approxLen + detailLen )
{
if( workIndex % 2 == 0 ) // calculate cA
{
vtkm::Id xl = xlstart + workIndex;
VAL sum=MAKEVAL(0.0);
for( vtkm::Id k = filterLen - 1; k >= 0; k-- )
sum += lowFilter.Get(k) * MAKEVAL( signalIn.Get(xl++) );
vtkm::Id outputIdx = workIndex / 2; // put cA at the beginning
coeffOut.Set( outputIdx, static_cast<OutputValueType>(sum) );
}
else // calculate cD
{
VAL sum=MAKEVAL(0.0);
vtkm::Id xh = xhstart + workIndex - 1;
for( vtkm::Id k = filterLen - 1; k >= 0; k-- )
sum += highFilter.Get(k) * MAKEVAL( signalIn.Get(xh++) );
vtkm::Id outputIdx = approxLen + (workIndex-1) / 2; // put cD after cA
coeffOut.Set( outputIdx, static_cast<OutputValueType>(sum) );
}
}
}
#undef MAKEVAL
#undef VAL
private:
const typename vtkm::cont::ArrayHandle<vtkm::Float64>::
ExecutionTypes<DeviceTag>::PortalConst lowFilter, highFilter;
const vtkm::Id filterLen, approxLen, detailLen; // filter and outcome coeff length.
bool oddlow, oddhigh;
vtkm::Id xlstart, xhstart;
VTKM_EXEC_CONT
void SetStartPosition()
{
this->xlstart = this->oddlow ? 1 : 0;
this->xhstart = this->oddhigh ? 1 : 0;
}
};
// ---------------------------------------------------
// | | | | | | |
// | | | | | | |
@ -889,6 +812,83 @@ private:
// Worklet: perform a simple 1D forward transform
template< typename DeviceTag >
class ForwardTransform: public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature(WholeArrayIn<ScalarAll>, // sigIn
WholeArrayOut<ScalarAll>); // cA followed by cD
typedef void ExecutionSignature(_1, _2, WorkIndex);
typedef _1 InputDomain;
// Constructor
VTKM_EXEC_CONT
ForwardTransform( const vtkm::cont::ArrayHandle<vtkm::Float64> &loFilter,
const vtkm::cont::ArrayHandle<vtkm::Float64> &hiFilter,
vtkm::Id filLen, vtkm::Id approx_len, vtkm::Id detail_len,
bool odd_low, bool odd_high ) :
lowFilter( loFilter.PrepareForInput(DeviceTag()) ),
highFilter( hiFilter.PrepareForInput(DeviceTag()) ),
filterLen( filLen ),
approxLen( approx_len ),
detailLen( detail_len ),
oddlow ( odd_low ),
oddhigh ( odd_high )
{ this->SetStartPosition(); }
// Use 64-bit float for convolution calculation
#define VAL vtkm::Float64
#define MAKEVAL(a) (static_cast<VAL>(a))
template <typename InputPortalType, typename OutputPortalType>
VTKM_EXEC
void operator()(const InputPortalType &signalIn,
OutputPortalType &coeffOut,
const vtkm::Id &workIndex) const
{
typedef typename OutputPortalType::ValueType OutputValueType;
if( workIndex < approxLen + detailLen )
{
if( workIndex % 2 == 0 ) // calculate cA
{
vtkm::Id xl = xlstart + workIndex;
VAL sum=MAKEVAL(0.0);
for( vtkm::Id k = filterLen - 1; k >= 0; k-- )
sum += lowFilter.Get(k) * MAKEVAL( signalIn.Get(xl++) );
vtkm::Id outputIdx = workIndex / 2; // put cA at the beginning
coeffOut.Set( outputIdx, static_cast<OutputValueType>(sum) );
}
else // calculate cD
{
VAL sum=MAKEVAL(0.0);
vtkm::Id xh = xhstart + workIndex - 1;
for( vtkm::Id k = filterLen - 1; k >= 0; k-- )
sum += highFilter.Get(k) * MAKEVAL( signalIn.Get(xh++) );
vtkm::Id outputIdx = approxLen + (workIndex-1) / 2; // put cD after cA
coeffOut.Set( outputIdx, static_cast<OutputValueType>(sum) );
}
}
}
#undef MAKEVAL
#undef VAL
private:
const typename vtkm::cont::ArrayHandle<vtkm::Float64>::
ExecutionTypes<DeviceTag>::PortalConst lowFilter, highFilter;
const vtkm::Id filterLen, approxLen, detailLen; // filter and outcome coeff length.
bool oddlow, oddhigh;
vtkm::Id xlstart, xhstart;
VTKM_EXEC_CONT
void SetStartPosition()
{
this->xlstart = this->oddlow ? 1 : 0;
this->xhstart = this->oddhigh ? 1 : 0;
}
};
// Worklet: perform an 1D inverse transform for odd length, symmetric filters.
template< typename DeviceTag >
class InverseTransformOdd: public vtkm::worklet::WorkletMapField