Enable SortByKey Test on the Device Adapter.

This commit is contained in:
Robert Maynard 2015-04-23 13:25:37 -04:00
parent e201778cc0
commit 6564d7af1c
3 changed files with 121 additions and 64 deletions

@ -339,7 +339,7 @@ private:
template<class KeysPortal, class ValuesPortal>
VTKM_CONT_EXPORT static void SortByKeyPortal(const KeysPortal &keys,
const ValuesPortal &values)
const ValuesPortal &values)
{
::thrust::sort_by_key(IteratorBegin(keys),
IteratorEnd(keys),
@ -348,8 +348,8 @@ private:
template<class KeysPortal, class ValuesPortal, class Compare>
VTKM_CONT_EXPORT static void SortByKeyPortal(const KeysPortal &keys,
const ValuesPortal &values,
Compare comp)
const ValuesPortal &values,
Compare comp)
{
::thrust::sort_by_key(IteratorBegin(keys),
IteratorEnd(keys),

@ -23,6 +23,7 @@
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/ArrayPortalToIterators.h>
#include <vtkm/cont/ArrayHandleZip.h>
#include <vtkm/cont/StorageBasic.h>
#include <vtkm/exec/FunctorBase.h>
@ -588,6 +589,62 @@ public:
DerivedAlgorithm::Sort(values, DefaultCompareFunctor());
}
//--------------------------------------------------------------------------
// Sort by Key
private:
template<typename T, typename U, class Compare=DefaultCompareFunctor>
struct KeyCompare
{
KeyCompare(): CompareFunctor() {}
explicit KeyCompare(Compare c): CompareFunctor(c) {}
VTKM_EXEC_EXPORT
bool operator()(const vtkm::Pair<T,U>& a, const vtkm::Pair<T,U>& b) const
{
return CompareFunctor(a.first,b.first);
}
private:
Compare CompareFunctor;
};
public:
template<typename T, typename U, class StorageT, class StorageU>
VTKM_CONT_EXPORT static void SortByKey(
vtkm::cont::ArrayHandle<T,StorageT> &keys,
vtkm::cont::ArrayHandle<U,StorageU> &values)
{
//combine the keys and values into a ZipArrayHandle
//we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using a custom compare functor.
typedef vtkm::cont::ArrayHandle<T,StorageT> KeyType;
typedef vtkm::cont::ArrayHandle<U,StorageU> ValueType;
typedef vtkm::cont::ArrayHandleZip<KeyType,ValueType> ZipHandleType;
ZipHandleType zipHandle =
vtkm::cont::make_ArrayHandleZip(keys,values);
DerivedAlgorithm::Sort(zipHandle,KeyCompare<T,U>());
}
template<typename T, typename U, class StorageT, class StorageU, class Compare>
VTKM_CONT_EXPORT static void SortByKey(
vtkm::cont::ArrayHandle<T,StorageT> &keys,
vtkm::cont::ArrayHandle<U,StorageU> &values,
Compare comp)
{
//combine the keys and values into a ZipArrayHandle
//we than need to specify a custom compare function wrapper
//that only checks for key side of the pair, using the custom compare
//functor that the user passed in
typedef vtkm::cont::ArrayHandle<T,StorageT> KeyType;
typedef vtkm::cont::ArrayHandle<U,StorageU> ValueType;
typedef vtkm::cont::ArrayHandleZip<KeyType,ValueType> ZipHandleType;
ZipHandleType zipHandle =
vtkm::cont::make_ArrayHandleZip(keys,values);
DerivedAlgorithm::Sort(zipHandle,KeyCompare<T,U,Compare>(comp));
}
//--------------------------------------------------------------------------
// Stream Compact
private:

@ -882,76 +882,76 @@ private:
}
}
// static VTKM_CONT_EXPORT void TestSortByKey()
// {
// std::cout << "-------------------------------------------------" << std::endl;
// std::cout << "Sort by keys" << std::endl;
static VTKM_CONT_EXPORT void TestSortByKey()
{
std::cout << "-------------------------------------------------" << std::endl;
std::cout << "Sort by keys" << std::endl;
// vtkm::Id testKeys[ARRAY_SIZE];
// vtkm::Vector3 testValues[ARRAY_SIZE];
vtkm::Id testKeys[ARRAY_SIZE];
vtkm::Vec<FloatDefault,3> testValues[ARRAY_SIZE];
// vtkm::Vector3 grad(1.0,1.0,1.0);
// for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
// {
// testKeys[i] = ARRAY_SIZE - i;
// testValues[i] = vtkm::Vector3(i);
// }
vtkm::Vec<FloatDefault,3> grad(1.0,1.0,1.0);
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
testKeys[i] = ARRAY_SIZE - i;
testValues[i] = vtkm::Vec<FloatDefault,3>(i);
}
// IdArrayHandle keys = MakeArrayHandle(testKeys, ARRAY_SIZE);
// Vec3ArrayHandle values = MakeArrayHandle(testValues, ARRAY_SIZE);
IdArrayHandle keys = MakeArrayHandle(testKeys, ARRAY_SIZE);
Vec3ArrayHandle values = MakeArrayHandle(testValues, ARRAY_SIZE);
// IdArrayHandle sorted_keys;
// Vec3ArrayHandle sorted_values;
IdArrayHandle sorted_keys;
Vec3ArrayHandle sorted_values;
// Algorithm::Copy(keys,sorted_keys);
// Algorithm::Copy(values,sorted_values);
Algorithm::Copy(keys,sorted_keys);
Algorithm::Copy(values,sorted_values);
// Algorithm::SortByKey(sorted_keys,sorted_values);
// for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
// {
// //keys should be sorted from 1 to ARRAY_SIZE
// //values should be sorted from (ARRAY_SIZE-1) to 0
// vtkm::FloatDefault sorted_value =
// sorted_values.GetPortalConstControl().Get(i)[0];
// vtkm::Id sorted_key = sorted_keys.GetPortalConstControl().Get(i);
Algorithm::SortByKey(sorted_keys,sorted_values);
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
//keys should be sorted from 1 to ARRAY_SIZE
//values should be sorted from (ARRAY_SIZE-1) to 0
vtkm::FloatDefault sorted_value =
sorted_values.GetPortalConstControl().Get(i)[0];
vtkm::Id sorted_key = sorted_keys.GetPortalConstControl().Get(i);
// VTKM_TEST_ASSERT( (sorted_key == (i+1)) , "Got bad SortByKeys key");
// VTKM_TEST_ASSERT( (sorted_value == (ARRAY_SIZE-1-i)),
// "Got bad SortByKeys value");
// }
VTKM_TEST_ASSERT( (sorted_key == (i+1)) , "Got bad SortByKeys key");
VTKM_TEST_ASSERT( (sorted_value == (ARRAY_SIZE-1-i)),
"Got bad SortByKeys value");
}
// // this will return everything back to what it was before sorting
// Algorithm::SortByKey(sorted_keys,sorted_values,comparison::SortGreater());
// for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
// {
// //keys should be sorted from ARRAY_SIZE to 1
// //values should be sorted from 0 to (ARRAY_SIZE-1)
// vtkm::FloatDefault sorted_value =
// sorted_values.GetPortalConstControl().Get(i)[0];
// vtkm::Id sorted_key = sorted_keys.GetPortalConstControl().Get(i);
// this will return everything back to what it was before sorting
Algorithm::SortByKey(sorted_keys,sorted_values,comparison::SortGreater());
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
//keys should be sorted from ARRAY_SIZE to 1
//values should be sorted from 0 to (ARRAY_SIZE-1)
vtkm::FloatDefault sorted_value =
sorted_values.GetPortalConstControl().Get(i)[0];
vtkm::Id sorted_key = sorted_keys.GetPortalConstControl().Get(i);
// VTKM_TEST_ASSERT( (sorted_key == (ARRAY_SIZE-i)),
// "Got bad SortByKeys key");
// VTKM_TEST_ASSERT( (sorted_value == i),
// "Got bad SortByKeys value");
// }
VTKM_TEST_ASSERT( (sorted_key == (ARRAY_SIZE-i)),
"Got bad SortByKeys key");
VTKM_TEST_ASSERT( (sorted_value == i),
"Got bad SortByKeys value");
}
// //this is here to verify we can sort by vtkm::Tuples
// Algorithm::SortByKey(sorted_values,sorted_keys);
// for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
// {
// //keys should be sorted from ARRAY_SIZE to 1
// //values should be sorted from 0 to (ARRAY_SIZE-1)
// vtkm::FloatDefault sorted_value =
// sorted_values.GetPortalConstControl().Get(i)[0];
// vtkm::Id sorted_key = sorted_keys.GetPortalConstControl().Get(i);
//this is here to verify we can sort by vtkm::Tuples
Algorithm::SortByKey(sorted_values,sorted_keys);
for(vtkm::Id i=0; i < ARRAY_SIZE; ++i)
{
//keys should be sorted from ARRAY_SIZE to 1
//values should be sorted from 0 to (ARRAY_SIZE-1)
vtkm::FloatDefault sorted_value =
sorted_values.GetPortalConstControl().Get(i)[0];
vtkm::Id sorted_key = sorted_keys.GetPortalConstControl().Get(i);
// VTKM_TEST_ASSERT( (sorted_key == (ARRAY_SIZE-i)),
// "Got bad SortByKeys key");
// VTKM_TEST_ASSERT( (sorted_value == i),
// "Got bad SortByKeys value");
// }
// }
VTKM_TEST_ASSERT( (sorted_key == (ARRAY_SIZE-i)),
"Got bad SortByKeys key");
VTKM_TEST_ASSERT( (sorted_value == i),
"Got bad SortByKeys value");
}
}
static VTKM_CONT_EXPORT void TestLowerBoundsWithComparisonObject()
{
@ -1344,7 +1344,7 @@ private:
TestScanExclusive();
TestSort();
TestSortWithComparisonObject();
// // TestSortByKey();
TestSortByKey();
TestLowerBoundsWithComparisonObject();
TestUpperBoundsWithComparisonObject();
TestUniqueWithComparisonObject();
@ -1363,7 +1363,7 @@ private:
public:
/// Run a suite of tests to check to see if a DeviceAdapter properly supports
/// all members and classes required for driving Dax algorithms. Returns an
/// all members and classes required for driving vtkm algorithms. Returns an
/// error code that can be returned from the main function of a test.
///
static VTKM_CONT_EXPORT int Run()