Refactor DynamicArrayHandle CastAndCall

This is the first step in making a more efficient CastAndCall for
DynamicArrayHandle.
This commit is contained in:
Robert Maynard 2017-11-06 10:12:49 -05:00
parent a2dd575f86
commit 3701776e8d
6 changed files with 166 additions and 124 deletions

@ -73,7 +73,7 @@ struct ListTagJoin : detail::ListRoot
using list = typename detail::ListJoin<typename ListTag1::list, typename ListTag2::list>::type;
};
/// A tag that consits of elements that are found in both tags. This struct
/// A tag that consists of elements that are found in both tags. This struct
/// can be subclassed and still behave like a list tag.
template <typename ListTag1, typename ListTag2>
struct ListTagIntersect : detail::ListRoot
@ -92,6 +92,17 @@ VTKM_CONT void ListForEach(Functor&& f, ListTag)
detail::ListForEachImpl(f, typename ListTag::list());
}
/// Generate a tag that is the cross product of two other tags. The resulting
// a tag has the form of Tag< std::pair<A1,B1>, std::pair<A1,B2> .... >
///
///
template <typename ListTag1, typename ListTag2>
struct ListCrossProduct : detail::ListRoot
{
using list =
typename detail::ListCrossProductImpl<typename ListTag1::list, typename ListTag2::list>::type;
};
/// Checks to see if the given \c Type is in the list pointed to by \c ListTag.
/// There is a static boolean named \c value that is set to true if the type is
/// contained in the list and false otherwise.

@ -72,7 +72,17 @@ struct StorageListTagCoordinateSystemDefault
vtkm::cont::ArrayHandle<vtkm::FloatDefault>>::StorageTag>
{
};
}
}
namespace vtkm
{
template struct ListCrossProduct<::vtkm::TypeListTagFieldVec3,
::vtkm::cont::StorageListTagCoordinateSystemDefault>;
namespace cont
{
using DynamicArrayHandleCoordinateSystem =
vtkm::cont::DynamicArrayHandleBase<VTKM_DEFAULT_COORDINATE_SYSTEM_TYPE_LIST_TAG,
VTKM_DEFAULT_COORDINATE_SYSTEM_STORAGE_LIST_TAG>;

@ -18,6 +18,8 @@
// this software.
//============================================================================
#include <sstream>
#include <typeindex>
#include <vtkm/cont/DynamicArrayHandle.h>
namespace vtkm
@ -34,6 +36,18 @@ PolymorphicArrayHandleContainerBase::PolymorphicArrayHandleContainerBase()
PolymorphicArrayHandleContainerBase::~PolymorphicArrayHandleContainerBase()
{
}
void ThrowCastAndCallException(PolymorphicArrayHandleContainerBase* ptr,
const std::type_info* type,
const std::type_info* storage)
{
std::ostringstream out;
out << "Could not find appropriate cast for array in CastAndCall1.\n"
"Array: ";
ptr->PrintSummary(out);
out << "TypeList: " << type->name() << "\nStorageList: " << storage->name() << "\n";
throw vtkm::cont::ErrorBadValue(out.str());
}
}
}
} // namespace vtkm::cont::detail

@ -31,10 +31,11 @@
#include <vtkm/cont/internal/DynamicTransform.h>
#include <sstream>
namespace vtkm
{
template struct ListCrossProduct<VTKM_DEFAULT_TYPE_LIST_TAG, VTKM_DEFAULT_STORAGE_LIST_TAG>;
namespace cont
{
@ -406,135 +407,70 @@ using DynamicArrayHandle =
namespace detail
{
template <typename Functor, typename Type>
struct DynamicArrayHandleTryStorage
template <typename Functor>
struct ListFunctorWrapper
{
const DynamicArrayHandle* const Array;
const Functor& Function;
bool FoundCast;
VTKM_CONT
DynamicArrayHandleTryStorage(const DynamicArrayHandle& array, const Functor& f)
: Array(&array)
ListFunctorWrapper(bool& called, const Functor& f, PolymorphicArrayHandleContainerBase* c)
: Called(called)
, Container(c)
, Function(f)
, FoundCast(false)
{
}
template <typename Storage>
VTKM_CONT void operator()(Storage)
template <typename T, typename U, typename... Args>
void operator()(std::pair<T, U>&& p, Args&&... args) const
{
this->DoCast(Storage(),
typename vtkm::cont::internal::IsValidArrayHandle<Type, Storage>::type());
using storage = vtkm::cont::internal::Storage<T, U>;
using invalid = typename std::is_base_of<vtkm::cont::internal::UndefinedStorage, storage>::type;
this->run(std::forward<decltype(p)>(p), invalid{}, args...);
}
private:
template <typename Storage>
void DoCast(Storage, std::true_type)
template <typename T, typename U, typename... Args>
void run(std::pair<T, U>&&, std::false_type, Args&&... args) const
{
if (!this->FoundCast && this->Array->template IsTypeAndStorage<Type, Storage>())
if (!this->Called)
{
this->Function(this->Array->template CastToTypeStorage<Type, Storage>());
this->FoundCast = true;
vtkm::cont::ArrayHandle<T, U>* handle = DynamicArrayHandleTryCast<T, U>(this->Container);
if (handle)
{
this->Function(*handle, std::forward<Args>(args)...);
this->Called = true;
}
}
}
template <typename Storage>
void DoCast(Storage, std::false_type)
template <typename T, typename U, typename... Args>
void run(std::pair<T, U>&&, std::true_type, Args&&...) const
{
// This type of array handle cannot exist, so do nothing.
}
void operator=(const DynamicArrayHandleTryStorage<Functor, Type>&) = delete;
};
template <typename Functor, typename StorageList>
struct DynamicArrayHandleTryType
{
const DynamicArrayHandle* const Array;
const Functor& Function;
bool FoundCast;
VTKM_CONT
DynamicArrayHandleTryType(const DynamicArrayHandle& array, const Functor& f)
: Array(&array)
, Function(f)
, FoundCast(false)
{
}
template <typename Type>
VTKM_CONT void operator()(Type)
{
if (this->FoundCast)
{
return;
}
using TryStorageType = DynamicArrayHandleTryStorage<Functor, Type>;
TryStorageType tryStorage = TryStorageType(*this->Array, this->Function);
vtkm::ListForEach(tryStorage, StorageList());
if (tryStorage.FoundCast)
{
this->FoundCast = true;
}
}
private:
void operator=(const DynamicArrayHandleTryType<Functor, StorageList>&) = delete;
bool& Called;
PolymorphicArrayHandleContainerBase* Container;
const Functor& Function;
};
VTKM_CONT_EXPORT void ThrowCastAndCallException(PolymorphicArrayHandleContainerBase*,
const std::type_info*,
const std::type_info*);
} // namespace detail
template <typename TypeList, typename StorageList>
template <typename Functor>
VTKM_CONT void DynamicArrayHandleBase<TypeList, StorageList>::CastAndCall(const Functor& f) const
{
VTKM_IS_LIST_TAG(TypeList);
VTKM_IS_LIST_TAG(StorageList);
using TryTypeType = detail::DynamicArrayHandleTryType<Functor, StorageList>;
//For optimizations we should compile once the cross product for the default types
//and make it extern
using crossProduct = typename vtkm::ListCrossProduct<TypeList, StorageList>;
// We cast this to a DynamicArrayHandle because at this point we are ignoring
// the type/storage lists in it. There is no sense in adding more unnecessary
// template cases.
// The downside to this approach is that a copy is created, causing an
// atomic increment, which affects both performance and library size.
// For these reasons we have a specialization of this method to remove
// the copy when the type/storage lists are the default
DynamicArrayHandle t(*this);
TryTypeType tryType = TryTypeType(t, f);
auto* ptr = this->ArrayContainer.get();
bool called = false;
auto task = detail::ListFunctorWrapper<Functor>(called, f, ptr);
vtkm::ListForEach(tryType, TypeList());
if (!tryType.FoundCast)
vtkm::ListForEach(task, crossProduct{});
if (!called)
{
std::ostringstream out;
out << "Could not find appropriate cast for array in CastAndCall1.\n"
"Array: ";
this->PrintSummary(out);
out << "TypeList: " << typeid(TypeList).name()
<< "\nStorageList: " << typeid(StorageList).name() << "\n";
throw vtkm::cont::ErrorBadValue(out.str());
}
}
template <>
template <typename Functor>
VTKM_CONT void
DynamicArrayHandleBase<VTKM_DEFAULT_TYPE_LIST_TAG, VTKM_DEFAULT_STORAGE_LIST_TAG>::CastAndCall(
const Functor& f) const
{
using TryTypeType = detail::DynamicArrayHandleTryType<Functor, VTKM_DEFAULT_STORAGE_LIST_TAG>;
// We can remove the copy, as the current DynamicArrayHandle is already
// the default one, and no reason to do an atomic increment and increase
// library size, and reduce performance
TryTypeType tryType = TryTypeType(*this, f);
vtkm::ListForEach(tryType, VTKM_DEFAULT_TYPE_LIST_TAG());
if (!tryType.FoundCast)
{
throw vtkm::cont::ErrorBadValue("Could not find appropriate cast for array in CastAndCall2.");
// throw an exception
detail::ThrowCastAndCallException(ptr, &typeid(TypeList), &typeid(StorageList));
}
}
@ -549,6 +485,7 @@ struct DynamicTransformTraits<vtkm::cont::DynamicArrayHandleBase<TypeList, Stora
} // namespace internal
}
} // namespace vtkm::cont
#endif //vtk_m_cont_DynamicArrayHandle_h

@ -206,6 +206,35 @@ VTKM_CONT void ListForEachImpl(Functor&& f, brigand::list<T1, T2, T3, T4, ArgTyp
ListForEachImpl(f, brigand::list<ArgTypes...>());
}
template <typename T, typename U, typename R>
struct ListCrossProductAppend
{
using type = brigand::push_back<T, std::pair<U, R>>;
};
template <typename T, typename U, typename R2>
struct ListCrossProductImplUnrollR2
{
using P =
brigand::fold<R2,
brigand::list<>,
ListCrossProductAppend<brigand::_state, brigand::_element, brigand::pin<U>>>;
using type = brigand::append<T, P>;
};
template <typename R1, typename R2>
struct ListCrossProductImpl
{
using type = brigand::fold<
R2,
brigand::list<>,
ListCrossProductImplUnrollR2<brigand::_state, brigand::_element, brigand::pin<R1>>>;
};
} // namespace detail
//-----------------------------------------------------------------------------

@ -33,10 +33,6 @@ namespace
template <int N>
struct TestClass
{
enum
{
NUMBER = N
};
};
struct TestListTag1 : vtkm::ListTagBase<TestClass<11>>
@ -63,36 +59,57 @@ struct TestListTagIntersect : vtkm::ListTagIntersect<TestListTag3, TestListTagJo
{
};
struct TestListTagCrossProduct : vtkm::ListCrossProduct<TestListTag3, TestListTag1>
{
};
struct TestListTagUniversal : vtkm::ListTagUniversal
{
};
template <int N, int M>
std::pair<int, int> test_number(std::pair<TestClass<N>, TestClass<M>>)
{
return std::make_pair(N, M);
}
template <int N>
int test_number(TestClass<N>)
{
return N;
}
template <typename T>
struct MutableFunctor
{
std::vector<int> FoundTypes;
std::vector<T> FoundTypes;
template <typename T>
VTKM_CONT void operator()(T)
template <typename U>
VTKM_CONT void operator()(U u)
{
this->FoundTypes.push_back(T::NUMBER);
this->FoundTypes.push_back(test_number(u));
}
};
std::vector<int> g_FoundType;
template <typename T>
struct ConstantFunctor
{
ConstantFunctor() { g_FoundType.erase(g_FoundType.begin(), g_FoundType.end()); }
std::vector<T>& FoundTypes;
template <typename T>
VTKM_CONT void operator()(T) const
ConstantFunctor(std::vector<T>& values)
: FoundTypes(values)
{
g_FoundType.push_back(T::NUMBER);
}
template <typename U>
VTKM_CONT void operator()(U u) const
{
this->FoundTypes.push_back(test_number(u));
}
};
template <vtkm::IdComponent N>
void CheckSame(const vtkm::Vec<int, N>& expected, const std::vector<int>& found)
template <typename T, vtkm::IdComponent N>
void CheckSame(const vtkm::Vec<T, N>& expected, const std::vector<T>& found)
{
VTKM_TEST_ASSERT(static_cast<int>(found.size()) == N, "Got wrong number of items.");
@ -137,13 +154,15 @@ void TryList(const vtkm::Vec<int, N>& expected, ListTag)
VTKM_IS_LIST_TAG(ListTag);
std::cout << " Try mutable for each" << std::endl;
MutableFunctor functor;
MutableFunctor<int> functor;
vtkm::ListForEach(functor, ListTag());
CheckSame(expected, functor.FoundTypes);
std::cout << " Try constant for each" << std::endl;
vtkm::ListForEach(ConstantFunctor(), ListTag());
CheckSame(expected, g_FoundType);
std::vector<int> foundTypes;
ConstantFunctor<int> cfunc(foundTypes);
vtkm::ListForEach(cfunc, ListTag());
CheckSame(expected, foundTypes);
std::cout << " Try checking contents" << std::endl;
CheckContains(TestClass<11>(), ListTag(), functor.FoundTypes);
@ -157,6 +176,22 @@ void TryList(const vtkm::Vec<int, N>& expected, ListTag)
CheckContains(TestClass<43>(), ListTag(), functor.FoundTypes);
CheckContains(TestClass<44>(), ListTag(), functor.FoundTypes);
}
template <vtkm::IdComponent N, typename ListTag>
void TryList(const vtkm::Vec<std::pair<int, int>, N>& expected, ListTag)
{
VTKM_IS_LIST_TAG(ListTag);
std::cout << " Try mutable for each" << std::endl;
MutableFunctor<std::pair<int, int>> functor;
vtkm::ListForEach(functor, ListTag());
CheckSame(expected, functor.FoundTypes);
std::cout << " Try constant for each" << std::endl;
std::vector<std::pair<int, int>> foundTypes;
ConstantFunctor<std::pair<int, int>> cfunc(foundTypes);
vtkm::ListForEach(cfunc, ListTag());
CheckSame(expected, foundTypes);
}
template <vtkm::IdComponent N>
void TryList(const vtkm::Vec<int, N>&, TestListTagUniversal tag)
@ -207,6 +242,12 @@ void TestLists()
std::cout << "ListTagIntersect" << std::endl;
TryList(vtkm::Vec<int, 3>(31, 32, 33), TestListTagIntersect());
std::cout << "ListTagCrossProduct" << std::endl;
TryList(vtkm::Vec<std::pair<int, int>, 3>({ 31, 11 }, { 32, 11 }, { 33, 11 }),
TestListTagCrossProduct());
std::cout << "ListTagUniversal" << std::endl;
TryList(vtkm::Vec<int, 4>(1, 2, 3, 4), TestListTagUniversal());
}