add worklet to copy a portion of array

This commit is contained in:
Samuel Li 2016-08-03 11:34:04 -04:00
parent b96d65550f
commit 433f89ae12
3 changed files with 56 additions and 112 deletions

@ -42,9 +42,9 @@ public:
// Multi-level 1D wavelet decomposition
template< typename SignalArrayType, typename CoeffArrayType>
VTKM_CONT_EXPORT
vtkm::Id WaveDecompose( const SignalArrayType &sigIn, // Input
vtkm::Id WaveDecompose( const SignalArrayType &sigIn, // Input
vtkm::Id nLevels, // n levels of DWT
CoeffArrayType &coeffOut,
CoeffArrayType &coeffOut,
vtkm::Id* L )
{
@ -62,11 +62,11 @@ public:
this->ComputeL( sigInLen, nLevels, L );
vtkm::Id CLength = this->ComputeCoeffLength( L, nLevels );
VTKM_ASSERT( CLength == sigIn.GetNumberOfValues() );
VTKM_ASSERT( CLength == sigInLen );
vtkm::Id sigInPtr = 0; // pseudo pointer for the beginning of input array
vtkm::Id len = sigIn.GetNumberOfValues();
vtkm::Id len = sigInLen;
vtkm::Id cALen = WaveletBase::GetApproxLength( len );
vtkm::Id cptr; // pseudo pointer for the beginning of output array
vtkm::Id tlen = 0;
@ -90,20 +90,23 @@ public:
cptr = 0 + CLength - tlen - cALen;
// make input array (permutation array)
IdArrayType inputIndices( sigInPtr, 1, len );
PermutArrayType input( inputIndices, coeffOut );
IdArrayType inputIndices( sigInPtr, 1, len );
PermutArrayType input( inputIndices, coeffOut );
// make output array
InterArrayType output;
InterArrayType output;
WaveletDWT::DWT1D( input, output, L1d );
// update interArray
// move intermediate results to final array
/*
vtkm::cont::ArrayPortalToIterators< InterPortalType >
outputIter( output.GetPortalControl() );
vtkm::cont::ArrayPortalToIterators< InterPortalType >
coeffOutIter( coeffOut.GetPortalControl() );
std::copy( outputIter.GetBegin(), outputIter.GetEnd(),
coeffOutIter.GetBegin() + cptr );
*/
WaveletBase::DeviceCopyStartX( output, coeffOut, cptr );
// update pseudo pointers
len = cALen;
@ -152,11 +155,14 @@ public:
WaveletDWT::IDWT1D( input, L1d, output );
// Move output to intermediate array
/*
vtkm::cont::ArrayPortalToIterators< typename OutArrayBasic::PortalControl >
outputIter( output.GetPortalControl() );
vtkm::cont::ArrayPortalToIterators< typename SignalArrayType::PortalControl >
sigOutIter( sigOut.GetPortalControl() );
std::copy( outputIter.GetBegin(), outputIter.GetEnd(), sigOutIter.GetBegin() );
*/
WaveletBase::DeviceCopyStartX( output, sigOut, 0 );
L1d[0] = L1d[2];
L1d[1] = L[i+1];

@ -171,6 +171,19 @@ public:
( srcArray, dstArray );
}
// perform a device copy. The whole 1st array to a certain start location of the 2nd array
template< typename ArrayType1, typename ArrayType2 >
VTKM_EXEC_CONT_EXPORT
void DeviceCopyStartX( const ArrayType1 &srcArray,
ArrayType2 &dstArray,
vtkm::Id startIdx)
{
typedef vtkm::worklet::wavelets::CopyWorklet CopyType;
CopyType cp( startIdx );
vtkm::worklet::DispatcherMapField< CopyType > dispatcher( cp );
dispatcher.Invoke( srcArray, dstArray );
}
// Sort by the absolute value on device
struct SortLessAbsFunctor
{

@ -30,110 +30,6 @@ namespace vtkm {
namespace worklet {
namespace wavelets {
#if 0
// Worklet: perform a simple forward transform
class ForwardTransform: public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature(WholeArrayIn<ScalarAll>, // sigIn
WholeArrayIn<Scalar>, // lowFilter
WholeArrayIn<Scalar>, // highFilter
FieldOut<ScalarAll>); // cA in even indices,
// cD in odd indices
typedef void ExecutionSignature(_1, _2, _3, _4, WorkIndex);
typedef _1 InputDomain;
// Constructor
VTKM_EXEC_CONT_EXPORT
ForwardTransform()
{
magicNum = 0.0;
oddlow = oddhigh = true;
filterLen = approxLen = detailLen = 0;
this->SetStartPosition();
}
// Specify odd or even for low and high coeffs
VTKM_EXEC_CONT_EXPORT
void SetOddness(bool odd_low, bool odd_high )
{
this->oddlow = odd_low;
this->oddhigh = odd_high;
this->SetStartPosition();
}
// Set the filter length
VTKM_EXEC_CONT_EXPORT
void SetFilterLength( vtkm::Id len )
{
this->filterLen = len;
}
// Set the outcome coefficient length
VTKM_EXEC_CONT_EXPORT
void SetCoeffLength( vtkm::Id approx_len, vtkm::Id detail_len )
{
this->approxLen = approx_len;
this->detailLen = detail_len;
}
// Use 64-bit float for convolution calculation
#define VAL vtkm::Float64
#define MAKEVAL(a) (static_cast<VAL>(a))
template <typename InputSignalPortalType,
typename FilterPortalType,
typename OutputCoeffType>
VTKM_EXEC_EXPORT
void operator()(const InputSignalPortalType &signalIn,
const FilterPortalType &lowFilter,
const FilterPortalType &highFilter,
OutputCoeffType &coeffOut,
const vtkm::Id &workIndex) const
{
if( workIndex % 2 == 0 ) // calculate cA, approximate coeffs
if( workIndex < approxLen + detailLen )
{
vtkm::Id xl = xlstart + workIndex;
VAL sum=MAKEVAL(0.0);
for( vtkm::Id k = filterLen - 1; k >= 0; k-- )
sum += lowFilter.Get(k) * MAKEVAL( signalIn.Get(xl++) );
coeffOut = static_cast<OutputCoeffType>( sum );
}
else
coeffOut = static_cast<OutputCoeffType>( magicNum );
else // calculate cD, detail coeffs
if( workIndex < approxLen + detailLen )
{
VAL sum=MAKEVAL(0.0);
vtkm::Id xh = xhstart + workIndex - 1;
for( vtkm::Id k = filterLen - 1; k >= 0; k-- )
sum += highFilter.Get(k) * MAKEVAL( signalIn.Get(xh++) );
coeffOut = static_cast<OutputCoeffType>( sum );
}
else
coeffOut = static_cast<OutputCoeffType>( magicNum );
}
#undef MAKEVAL
#undef VAL
private:
vtkm::Float64 magicNum;
vtkm::Id filterLen, approxLen, detailLen; // filter and outcome coeff length.
vtkm::Id xlstart, xhstart;
bool oddlow, oddhigh;
VTKM_EXEC_CONT_EXPORT
void SetStartPosition()
{
this->xlstart = this->oddlow ? 1 : 0;
this->xhstart = this->oddhigh ? 1 : 0;
}
}; // Finish class ForwardTransform
#endif
// Worklet: perform a simple forward transform
class ForwardTransform: public vtkm::worklet::WorkletMapField
@ -547,6 +443,35 @@ public:
};
class CopyWorklet : public vtkm::worklet::WorkletMapField
{
public:
typedef void ControlSignature( WholeArrayIn< ScalarAll >,
WholeArrayOut< ScalarAll > );
typedef void ExecutionSignature( _1, _2, WorkIndex );
typedef _1 InputDomain;
// Constructor
VTKM_EXEC_CONT_EXPORT
CopyWorklet( vtkm::Id idx )
{
this->startIdx = idx;
}
template< typename PortalInType, typename PortalOutType >
VTKM_EXEC_CONT_EXPORT
void operator()( const PortalInType &portalIn,
PortalOutType &portalOut,
const vtkm::Id &workIndex) const
{
portalOut.Set( (startIdx + workIndex), portalIn.Get(workIndex) );
}
private:
vtkm::Id startIdx;
};
} // namespace wavelets
} // namespace worlet
} // namespace vtkm