ForwardTransform worklet now have the correct interface

This commit is contained in:
Samuel Li 2016-07-06 15:21:12 -06:00
parent 1f4e847c56
commit 9c1e9574fa
2 changed files with 54 additions and 52 deletions

@ -28,51 +28,49 @@
namespace vtkm {
namespace worklet {
namespace internal {
// template <typename T>
// VTKM_EXEC_EXPORT
// T clamp(const T& val, const T& min, const T& max)
// {
// return vtkm::Min(max, vtkm::Max(min, val));
// }
}
class Wavelets : public vtkm::worklet::WorkletMapField
class Wavelets
{
public:
typedef void ControlSignature(FieldIn<>, FieldOut<>);
typedef _2 ExecutionSignature(_1);
VTKM_CONT_EXPORT
Wavelets() : magicNum(2.0) {}
VTKM_CONT_EXPORT
void SetMagicNum(const vtkm::Float64 &num)
// helper worklet
class ForwardTransform: public vtkm::worklet::WorkletMapField
{
this->magicNum = num;
}
public:
typedef void ControlSignature(WholeArrayIn<ScalarAll>, // sigIn
FieldOut<ScalarAll>, // cA
FieldOut<ScalarAll>); // cD
typedef void ExecutionSignature(_1, _2, _3, WorkIndex);
typedef _1 InputDomain;
// ForwardTransform constructor
VTKM_CONT_EXPORT
ForwardTransform()
{
magicNum = 2.0;
oddlow = true;
oddhigh = true;
}
VTKM_EXEC_EXPORT
vtkm::Float64 operator()(const vtkm::Float64 &inputVal) const
{
return inputVal * this->magicNum;
}
template <typename T, typename ArrayPortalType>
VTKM_EXEC_EXPORT
void operator()(const ArrayPortalType &signalIn,
T &coeffApproximation,
T &coeffDetail,
const vtkm::Id &workIndex) const
{
vtkm::Float64 tmp = static_cast<vtkm::Float64>(signalIn.Get(workIndex));
coeffApproximation = static_cast<T>( tmp / 2.0 );
coeffDetail = static_cast<T>( tmp * 2.0 );
}
template <typename T>
VTKM_EXEC_EXPORT
vtkm::Float64 operator()(const T &inputVal) const
{
return (*this)(static_cast<vtkm::Float64>(inputVal));
}
private:
vtkm::Float64 magicNum;
bool oddlow, oddhigh;
}; // class ForwardTransform
private:
vtkm::Float64 magicNum;
};
}; // class Wavelets
}
} // namespace vtkm::worklet
} // namespace worlet
} // namespace vtkm
#endif // vtk_m_worklet_Wavelets_h

@ -33,26 +33,30 @@ void TestWavelets()
vtkm::Id arraySize = 10;
std::vector<vtkm::Float32> tmpVector;
for( vtkm::Id i = 0; i < arraySize; i++ )
tmpVector.push_back(static_cast<vtkm::Float32>(i));
tmpVector.push_back(static_cast<vtkm::Float32>(i*2));
vtkm::cont::ArrayHandle<vtkm::Float32> input1DArray =
vtkm::cont::make_ArrayHandle(tmpVector);
vtkm::cont::ArrayHandle<vtkm::Float32> output1DArray;
vtkm::cont::ArrayHandle<vtkm::Float32> outputArray1;
outputArray1.Allocate( arraySize );
vtkm::cont::ArrayHandle<vtkm::Float32> outputArray2;
outputArray2.Allocate( arraySize );
vtkm::worklet::Wavelets::ForwardTransform forwardTransform;
vtkm::worklet::DispatcherMapField<vtkm::worklet::Wavelets::ForwardTransform> dispatcher(forwardTransform);
dispatcher.Invoke(input1DArray, outputArray1, outputArray2);
vtkm::worklet::Wavelets waveletsWorklet;
vtkm::worklet::DispatcherMapField<vtkm::worklet::Wavelets>
dispatcher(waveletsWorklet);
dispatcher.Invoke(input1DArray, output1DArray);
std::cerr << "Invoke succeeded" << std::endl;
for (vtkm::Id i = 0; i < output1DArray.GetNumberOfValues(); ++i)
for (vtkm::Id i = 0; i < outputArray1.GetNumberOfValues(); ++i)
{
std::cout<< output1DArray.GetPortalConstControl().Get(i) << std::endl;
VTKM_TEST_ASSERT(
test_equal( output1DArray.GetPortalConstControl().Get(i),
static_cast<vtkm::Float32>(i) * 2.0f ),
"Wrong result for Wavelets worklet");
std::cout<< outputArray1.GetPortalConstControl().Get(i) << ", "
<< outputArray2.GetPortalConstControl().Get(i) << std::endl;
// VTKM_TEST_ASSERT(
// test_equal( output1DArray.GetPortalConstControl().Get(i),
// static_cast<vtkm::Float32>(i) * 2.0f ),
// "Wrong result for Wavelets worklet");
}
}