InverseTransform2D takes filters in constructor.

This commit is contained in:
Samuel Li 2016-09-08 17:35:17 -06:00
parent b813c4c9c5
commit 3ef6d284cf
2 changed files with 46 additions and 48 deletions

@ -923,8 +923,8 @@ if( print)
VTKM_ASSERT( inDimX * inDimY == coeffIn.GetNumberOfValues() );
vtkm::Id filterLen = WaveletBase::filter.GetFilterLength();
typedef vtkm::worklet::wavelets::InverseTransform2DOdd IdwtOddWorklet;
typedef vtkm::worklet::wavelets::InverseTransform2DEven IdwtEvenWorklet;
typedef vtkm::worklet::wavelets::InverseTransform2DOdd<DeviceTag> IdwtOddWorklet;
typedef vtkm::worklet::wavelets::InverseTransform2DEven<DeviceTag> IdwtEvenWorklet;
vtkm::Float64 elapsedTime = 0.0;
// First inverse transform on columns
@ -948,28 +948,26 @@ if( print)
afterY.PrepareForOutput( afterYDimX * afterYDimY, DeviceTag() );
if( filterLen % 2 != 0 )
{
IdwtOddWorklet worklet( filterLen, beforeYExtendDimX, beforeYExtendDimY,
IdwtOddWorklet worklet( WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
filterLen, beforeYExtendDimX, beforeYExtendDimY,
afterYDimX, afterYDimY, cATempLen );
vtkm::worklet::DispatcherMapField<IdwtOddWorklet, DeviceTag>
dispatcher( worklet );
vtkm::cont::Timer<DeviceTag> timer;
dispatcher.Invoke( beforeYExtend,
WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
afterY );
dispatcher.Invoke( beforeYExtend, afterY );
elapsedTime += timer.GetElapsedTime();
}
else
{
IdwtEvenWorklet worklet( filterLen, beforeYExtendDimX, beforeYExtendDimY,
IdwtEvenWorklet worklet( WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
filterLen, beforeYExtendDimX, beforeYExtendDimY,
afterYDimX, afterYDimY, cATempLen );
vtkm::worklet::DispatcherMapField<IdwtEvenWorklet, DeviceTag>
dispatcher( worklet );
vtkm::cont::Timer<DeviceTag> timer;
dispatcher.Invoke( beforeYExtend,
WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
afterY );
dispatcher.Invoke( beforeYExtend, afterY );
elapsedTime += timer.GetElapsedTime();
}
@ -997,28 +995,26 @@ if( print)
afterX.PrepareForOutput( afterXDimX * afterXDimY, DeviceTag() );
if( filterLen % 2 != 0 )
{
IdwtOddWorklet worklet( filterLen, beforeXExtendDimX, beforeXExtendDimY,
IdwtOddWorklet worklet( WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
filterLen, beforeXExtendDimX, beforeXExtendDimY,
afterXDimX, afterXDimY, cATempLen );
vtkm::worklet::DispatcherMapField<IdwtOddWorklet, DeviceTag>
dispatcher( worklet );
vtkm::cont::Timer<DeviceTag> timer;
dispatcher.Invoke( beforeXExtend,
WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
afterX );
dispatcher.Invoke( beforeXExtend, afterX );
elapsedTime += timer.GetElapsedTime();
}
else
{
IdwtEvenWorklet worklet( filterLen, beforeXExtendDimX, beforeXExtendDimY,
IdwtEvenWorklet worklet( WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
filterLen, beforeXExtendDimX, beforeXExtendDimY,
afterXDimX, afterXDimY, cATempLen );
vtkm::worklet::DispatcherMapField<IdwtEvenWorklet, DeviceTag>
dispatcher( worklet );
vtkm::cont::Timer<DeviceTag> timer;
dispatcher.Invoke( beforeXExtend,
WaveletBase::filter.GetLowReconstructFilter(),
WaveletBase::filter.GetHighReconstructFilter(),
afterX );
dispatcher.Invoke( beforeXExtend, afterX );
elapsedTime += timer.GetElapsedTime();
}

@ -234,21 +234,24 @@ private:
// Worklet: perform a simple 2D inverse transform on odd length filters
template< typename DeviceTag >
class InverseTransform2DOdd: public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature( WholeArrayIn< ScalarAll >, // input extended signal
WholeArrayIn< Scalar >,
WholeArrayIn< Scalar >,
FieldOut< ScalarAll> ); // outptu coeffs
typedef void ExecutionSignature( _1, _2, _3, _4, WorkIndex );
typedef _4 InputDomain;
typedef void ExecutionSignature( _1, _2, WorkIndex );
typedef _2 InputDomain;
// Constructor
VTKM_EXEC_CONT_EXPORT
InverseTransform2DOdd( vtkm::Id fil_len, vtkm::Id x1, vtkm::Id y1, vtkm::Id x2,
vtkm::Id y2, vtkm::Id cA_len_ext )
: filterLen( fil_len ), inputDimX( x1 ), inputDimY( y1 ),
InverseTransform2DOdd( const vtkm::cont::ArrayHandle<vtkm::Float64> &lo_fil,
const vtkm::cont::ArrayHandle<vtkm::Float64> &hi_fil,
vtkm::Id fil_len, vtkm::Id x1, vtkm::Id y1, vtkm::Id x2,
vtkm::Id y2, vtkm::Id cA_len_ext ) :
lowFilter( lo_fil.PrepareForInput( DeviceTag() ) ),
highFilter( hi_fil.PrepareForInput( DeviceTag() ) ),
filterLen( fil_len ), inputDimX( x1 ), inputDimY( y1 ),
outputDimX( x2 ), outputDimY( y2 ), cALenExtended( cA_len_ext ) {}
VTKM_EXEC_CONT_EXPORT
@ -267,13 +270,9 @@ public:
#define VAL vtkm::Float64
#define MAKEVAL(a) (static_cast<VAL>(a))
template< typename InputPortalType,
typename FilterPortalType,
typename OutputValueType >
template< typename InputPortalType, typename OutputValueType >
VTKM_EXEC_EXPORT
void operator() (const InputPortalType &sigIn,
const FilterPortalType &loFilter,
const FilterPortalType &hiFilter,
OutputValueType &coeffOut,
const vtkm::Id &workIdx ) const
{
@ -300,7 +299,7 @@ public:
while( k1 > -1 )
{
sigIdx1D = Input2Dto1D( xi, inY );
sum += loFilter.Get(k1) * MAKEVAL( sigIn.Get( sigIdx1D ) );
sum += lowFilter.Get(k1) * MAKEVAL( sigIn.Get( sigIdx1D ) );
xi++;
k1 -= 2;
}
@ -308,7 +307,7 @@ public:
while( k2 > -1 )
{
sigIdx1D = Input2Dto1D( xi + cALenExtended, inY );
sum += hiFilter.Get(k2) * MAKEVAL( sigIn.Get( sigIdx1D ) );
sum += highFilter.Get(k2) * MAKEVAL( sigIn.Get( sigIdx1D ) );
xi++;
k2 -= 2;
}
@ -319,6 +318,8 @@ public:
#undef VAL
private:
typename vtkm::cont::ArrayHandle<vtkm::Float64>::ExecutionTypes<DeviceTag>::PortalConst
lowFilter, highFilter;
vtkm::Id filterLen;
vtkm::Id inputDimX, inputDimY, outputDimX, outputDimY;
vtkm::Id cALenExtended; // Number of cA at the beginning of input, followed by cD
@ -406,23 +407,26 @@ private:
// Worklet: perform an inverse transform for even length, symmetric filters.
template< typename DeviceTag >
class InverseTransform2DEven: public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature(WholeArrayIn<ScalarAll>, // Input: coeffs,
// cA followed by cD
WholeArrayIn<Scalar>, // lowFilter
WholeArrayIn<Scalar>, // highFilter
FieldOut<ScalarAll>); // output
typedef void ExecutionSignature(_1, _2, _3, _4, WorkIndex);
typedef _4 InputDomain;
typedef void ExecutionSignature(_1, _2, WorkIndex);
typedef _2 InputDomain;
// Constructor
VTKM_EXEC_CONT_EXPORT
InverseTransform2DEven( vtkm::Id filtL, vtkm::Id x1, vtkm::Id y1,
InverseTransform2DEven( const vtkm::cont::ArrayHandle<vtkm::Float64> &lo_fil,
const vtkm::cont::ArrayHandle<vtkm::Float64> &hi_fil,
vtkm::Id filtL, vtkm::Id x1, vtkm::Id y1,
vtkm::Id x2, vtkm::Id y2, vtkm::Id cALExt ) :
filterLen(filtL), inputDimX( x1 ), inputDimY( y1 ),
outputDimX( x2 ), outputDimY( y2 ), cALenExtended(cALExt) {}
lowFilter( lo_fil.PrepareForInput( DeviceTag() ) ),
highFilter( hi_fil.PrepareForInput( DeviceTag() ) ),
filterLen(filtL), inputDimX( x1 ), inputDimY( y1 ),
outputDimX( x2 ), outputDimY( y2 ), cALenExtended(cALExt) {}
VTKM_EXEC_CONT_EXPORT
void Output1Dto2D( const vtkm::Id &idx, vtkm::Id &x, vtkm::Id &y ) const
@ -440,13 +444,9 @@ public:
#define VAL vtkm::Float64
#define MAKEVAL(a) (static_cast<VAL>(a))
template <typename InputPortalType,
typename FilterPortalType,
typename OutputValueType>
template <typename InputPortalType, typename OutputValueType>
VTKM_EXEC_EXPORT
void operator()(const InputPortalType &coeffs,
const FilterPortalType &lowFilter,
const FilterPortalType &highFilter,
OutputValueType &sigOut,
const vtkm::Id &workIndex) const
{
@ -491,6 +491,8 @@ public:
#undef VAL
private:
typename vtkm::cont::ArrayHandle<vtkm::Float64>::ExecutionTypes<DeviceTag>::PortalConst
lowFilter, highFilter;
vtkm::Id filterLen; // filter length.
vtkm::Id inputDimX, inputDimY, outputDimX, outputDimY;
vtkm::Id cALenExtended; // Number of cA at the beginning of input, followed by cD