DWT worklet uses WholeArrayOut to avoid a memory copy

This commit is contained in:
Samuel Li 2016-08-02 14:09:34 -04:00
parent 448e9be5a8
commit d88bdff94c
3 changed files with 114 additions and 25 deletions

@ -195,8 +195,9 @@ public:
WaveletBase::DeviceCopy( coeffIn, sortedArray );
WaveletBase::DeviceSort( sortedArray );
vtkm::Id n = vtkm::Ceil( static_cast<vtkm::Float64>( coeffLen ) /
static_cast<vtkm::Float64>( ratio ) );
vtkm::Id n = static_cast<vtkm::Id>(
vtkm::Ceil( static_cast<vtkm::Float64>( coeffLen ) /
static_cast<vtkm::Float64>( ratio ) ) );
ValueType threshold = sortedArray.GetPortalConstControl().Get( coeffLen - n );
if( threshold < 0.0 )
threshold *= -1.0;

@ -177,8 +177,8 @@ public:
vtkm::Id sigExtendedLen = sigInLen + 2 * addLen;
typedef typename SignalArrayType::ValueType SigInValueType;
typedef vtkm::cont::ArrayHandle<SigInValueType> SignalArrayTypeBasic;
typedef typename SignalArrayType::ValueType SigInValueType;
typedef vtkm::cont::ArrayHandle<SigInValueType> SignalArrayTypeBasic;
SignalArrayTypeBasic sigInExtended;
@ -198,36 +198,20 @@ public:
forwardTransform.SetCoeffLength( L[0], L[1] );
forwardTransform.SetOddness( oddLow, oddHigh );
coeffOut.Allocate( sigInExtended.GetNumberOfValues() );
vtkm::worklet::DispatcherMapField<vtkm::worklet::wavelets::ForwardTransform>
dispatcher(forwardTransform);
dispatcher.Invoke( sigInExtended,
WaveletBase::filter->GetLowDecomposeFilter(),
WaveletBase::filter->GetHighDecomposeFilter(),
coeffOutTmp );
// Separate cA and cD.
typedef vtkm::cont::ArrayHandleCounting< vtkm::Id > IdArrayType;
typedef vtkm::cont::ArrayHandlePermutation< IdArrayType, CoeffArrayType > PermutArrayType;
IdArrayType approxIndices( 0, 2, L[0] );
IdArrayType detailIndices( 1, 2, L[1] );
PermutArrayType cATmp( approxIndices, coeffOutTmp );
PermutArrayType cDTmp( detailIndices, coeffOutTmp );
typedef vtkm::cont::ArrayHandleConcatenate< PermutArrayType, PermutArrayType >
PermutArrayConcatenated;
PermutArrayConcatenated coeffOutConcat( cATmp, cDTmp );
/*
vtkm::cont::DeviceAdapterAlgorithm< VTKM_DEFAULT_DEVICE_ADAPTER_TAG>::Copy(
coeffOutConcat, coeffOut );
*/
WaveletBase::DeviceCopy( coeffOutConcat, coeffOut );
coeffOut );
VTKM_ASSERT( L[0] + L[1] <= coeffOut.GetNumberOfValues() );
coeffOut.Shrink( L[0] + L[1] );
return 0;
}
// Func:
// Performs one level of inverse wavelet transform

@ -30,6 +30,7 @@ namespace vtkm {
namespace worklet {
namespace wavelets {
#if 0
// Worklet: perform a simple forward transform
class ForwardTransform: public vtkm::worklet::WorkletMapField
{
@ -131,6 +132,109 @@ private:
this->xhstart = this->oddhigh ? 1 : 0;
}
}; // Finish class ForwardTransform
#endif
// Worklet: perform a simple forward transform.
//
// Each invocation handles a single work index:
//   - an even index below approxLen + detailLen yields an approximation
//     coefficient (cA), computed with the low filter;
//   - an odd index below approxLen + detailLen yields a detail coefficient
//     (cD), computed with the high filter;
//   - every other index is filled with magicNum.
// The cA values are packed at the front of the output array and the cD values
// start at offset approxLen.
class ForwardTransform: public vtkm::worklet::WorkletMapField
{
public:
  typedef void ControlSignature(WholeArrayIn<ScalarAll>,     // sigIn
                                WholeArrayIn<Scalar>,        // lowFilter
                                WholeArrayIn<Scalar>,        // highFilter
                                WholeArrayOut<ScalarAll>);   // cA followed by cD
  typedef void ExecutionSignature(_1, _2, _3, _4, WorkIndex);
  typedef _1   InputDomain;

  // Constructor: zero lengths, zero fill value, both oddness flags set.
  VTKM_EXEC_CONT_EXPORT
  ForwardTransform()
    : magicNum( 0.0 ),
      filterLen( 0 ), approxLen( 0 ), detailLen( 0 ),
      oddlow( true ), oddhigh( true )
  {
    // xlstart / xhstart are derived from the oddness flags.
    this->SetStartPosition();
  }

  // Specify odd or even for low and high coeffs.
  // The start positions are recomputed from the new flags.
  VTKM_EXEC_CONT_EXPORT
  void SetOddness(bool odd_low, bool odd_high )
  {
    this->oddlow  = odd_low;
    this->oddhigh = odd_high;
    this->SetStartPosition();
  }

  // Set the filter length (number of taps read from each filter).
  VTKM_EXEC_CONT_EXPORT
  void SetFilterLength( vtkm::Id len )
  {
    this->filterLen = len;
  }

  // Set the outcome coefficient lengths: approx_len cA values are followed by
  // detail_len cD values in the output.
  VTKM_EXEC_CONT_EXPORT
  void SetCoeffLength( vtkm::Id approx_len, vtkm::Id detail_len )
  {
    this->approxLen = approx_len;
    this->detailLen = detail_len;
  }

  // Compute the output value belonging to workIndex and store it in coeffOut.
  //   signalIn   : input signal portal
  //   lowFilter  : filter used for cA (even work indices)
  //   highFilter : filter used for cD (odd work indices)
  //   coeffOut   : output portal, cA followed by cD
  //   workIndex  : index of this invocation
  template <typename InputPortalType,
            typename FilterPortalType,
            typename OutputPortalType>
  VTKM_EXEC_EXPORT
  void operator()(const InputPortalType  &signalIn,
                  const FilterPortalType &lowFilter,
                  const FilterPortalType &highFilter,
                  OutputPortalType       &coeffOut,
                  const vtkm::Id         &workIndex) const
  {
    typedef typename OutputPortalType::ValueType ResultType;

    // Indices past the last coefficient only receive the fill value.
    if( workIndex >= this->approxLen + this->detailLen )
    {
      coeffOut.Set( workIndex, static_cast<ResultType>( this->magicNum ) );
      return;
    }

    const bool isApprox = ( workIndex % 2 == 0 );
    if( isApprox )
    {
      // cA: low filter, result stored in the first part of the output.
      const vtkm::Float64 acc =
          this->Convolve( signalIn, lowFilter, this->xlstart + workIndex );
      coeffOut.Set( workIndex / 2, static_cast<ResultType>( acc ) );
    }
    else
    {
      // cD: high filter, result stored after the approxLen cA values.
      const vtkm::Float64 acc =
          this->Convolve( signalIn, highFilter, this->xhstart + workIndex - 1 );
      coeffOut.Set( this->approxLen + ( workIndex - 1 ) / 2,
                    static_cast<ResultType>( acc ) );
    }
  }

private:
  vtkm::Float64 magicNum;                       // fill value for unused output slots
  vtkm::Id      filterLen, approxLen, detailLen; // filter and outcome coeff length.
  vtkm::Id      xlstart, xhstart;               // signal offsets for low / high filter
  bool          oddlow, oddhigh;

  // Derive the signal start offsets: 1 when the matching oddness flag is set,
  // 0 otherwise.
  VTKM_EXEC_CONT_EXPORT
  void SetStartPosition()
  {
    this->xlstart = this->oddlow  ? 1 : 0;
    this->xhstart = this->oddhigh ? 1 : 0;
  }

  // Accumulate filterLen products in 64-bit floating point. The filter is read
  // from its last tap down to tap 0 while the signal is read forward starting
  // at firstSample.
  template <typename SignalPortalType, typename TapPortalType>
  VTKM_EXEC_EXPORT
  vtkm::Float64 Convolve( const SignalPortalType &signal,
                          const TapPortalType    &taps,
                          vtkm::Id                firstSample ) const
  {
    vtkm::Float64 acc = static_cast<vtkm::Float64>( 0.0 );
    for( vtkm::Id offset = 0; offset < this->filterLen; ++offset )
    {
      acc += taps.Get( this->filterLen - 1 - offset ) *
             static_cast<vtkm::Float64>( signal.Get( firstSample + offset ) );
    }
    return acc;
  }
}; // Finish class ForwardTransform
// Worklet: perform an inverse transform for odd length, symmetric filters.