vtk-m/vtkm/internal/VariantImplDetail.h.in
Kenneth Moreland 6afff75501 Modify Variant CastAndCall to have fewer cases in its switch
The previous implementation of `Variant`'s `CastAndCall` generated a
switch statement with 20 cases (plus a default) regardless of how many
types were handled by the `Variant` (with the excess doing nothing
useful). This reduced the amount of code, but forced the compiler to
generate (and optimize) many more instructions. This in turn led to
long compile times and unnecessarily large libraries and executables.

This change adds a different implementation of `CastAndCall`, templated
on the number of types, so that the number of cases in the switch
exactly matches the number of types in the `Variant`'s union.
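Below is a minimal sketch of the idea. It assumes a hand-written three-type union in place of the generated `VariantUnion`; the names `ThreeUnion`, `CastAndCall3` and `SizeOfActive` are hypothetical.

The switch has exactly one case per type. As in the generated `VariantCases<N>` specializations, the `default` label shares the body of `case 0`.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for a generated VariantUnion holding three types.
union ThreeUnion
{
  int V0;
  float V1;
  double V2;
};

// Dispatch with exactly one case per type held by the union (3 here), rather
// than a fixed number of cases regardless of how many types there are.
template <typename Functor>
auto CastAndCall3(int index, Functor&& f, ThreeUnion& storage) -> decltype(f(storage.V0))
{
  assert((index >= 0) && (index < 3));
  switch (index)
  {
    default: // Out-of-range indices are a bug caught by the assert, so the
    case 0:  // default label simply shares the body of case 0.
      return f(storage.V0);
    case 1:
      return f(storage.V1);
    case 2:
      return f(storage.V2);
  }
}

// The functor is instantiated for every type in the union, not only the one
// currently held, so it must accept all of them (and return the same type).
struct SizeOfActive
{
  template <typename T>
  std::size_t operator()(const T&) const
  {
    return sizeof(T);
  }
};
```

For example, after `u.V2 = 3.5;`, the call `CastAndCall3(2, SizeOfActive{}, u)` returns `sizeof(double)`.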

Because VariantImplDetail.h was getting large, I also reduced the
maximum expansion for the generated code. This does not seem to
negatively affect compile time, and I doubt it will make a noticeable
difference in running time (in release mode).
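The maximum expansion is the number of types a generated union handles directly. Any additional types go into a nested `Remaining` union, and element lookup recurses with the index reduced by the expansion size.

Here is a minimal sketch of that nesting. It assumes an expansion of 2 rather than the generated code's 8, and uses hypothetical names (`ChunkUnion`, `ChunkGet`). It needs C++14 for the `auto&` return types.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical chunked union with an expansion of 2: the first two types are
// stored directly, and the rest go into a nested union named Remaining.
template <typename T0, typename... Ts>
union ChunkUnion;

template <typename T0>
union ChunkUnion<T0>
{
  T0 V0;
};

template <typename T0, typename T1>
union ChunkUnion<T0, T1>
{
  T0 V0;
  T1 V1;
};

template <typename T0, typename T1, typename T2, typename... Ts>
union ChunkUnion<T0, T1, T2, Ts...>
{
  T0 V0;
  T1 V1;
  ChunkUnion<T2, Ts...> Remaining;
};

// Indices of 2 or more recurse into Remaining with the index reduced by the
// expansion size. Only indices below the number of types are valid.
template <int I>
struct ChunkGet
{
  template <typename U>
  static auto& Get(U& u)
  {
    return ChunkGet<I - 2>::Get(u.Remaining);
  }
};

template <>
struct ChunkGet<0>
{
  template <typename U>
  static auto& Get(U& u)
  {
    return u.V0;
  }
};

template <>
struct ChunkGet<1>
{
  template <typename U>
  static auto& Get(U& u)
  {
    return u.V1;
  }
};
```

With five types, index 4 resolves to `u.Remaining.Remaining.V0`. A smaller expansion means fewer generated specializations, at the cost of deeper recursion for long type lists.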

I also modified some other parts of this code to match the expansion
without making unnecessary defaults.
2022-03-28 14:14:06 -06:00


//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
$# This file uses the pyexpander macro processing utility to build the
$# Variant implementation details (unions and dispatch code) for a variable
$# number of types. Information, documentation, and downloads for pyexpander
$# can be found at:
$#
$# http://pyexpander.sourceforge.net/
$#
$# To build the source code, execute the following (after installing
$# pyexpander, of course):
$#
$# expander.py VariantImplDetail.h.in > VariantImplDetail.h
$#
$# Ignore the following comment. It is meant for the generated file.
// **** DO NOT EDIT THIS FILE!!! ****
// This file is automatically generated by VariantImplDetail.h.in

#if !defined(VTK_M_DEVICE) || !defined(VTK_M_NAMESPACE)
#error VariantImplDetail.h must be included from VariantImpl.h
// Some defines to make my IDE happy.
#define VTK_M_DEVICE
#define VTK_M_NAMESPACE tmp
#endif

#include <vtkm/List.h>
#include <vtkm/Types.h>
#include <vtkm/internal/Assume.h>
#include <vtkmstd/is_trivial.h>
#include <type_traits>
$py(max_expanded=8)\
$# Python commands used in template expansion.
$py(
def type_list(num_params):
  if num_params < 0:
    return ''
  result = 'T0'
  for param in range(1, num_params + 1):
    result += ', T%d' % param
  return result

def typename_list(num_params):
  if num_params < 0:
    return ''
  result = 'typename T0'
  for param in range(1, num_params + 1):
    result += ', typename T%d' % param
  return result
)\
$#
$extend(type_list, typename_list)\
namespace vtkm
{
namespace VTK_M_NAMESPACE
{
namespace internal
{
namespace detail
{
// --------------------------------------------------------------------------------
// Helper classes to determine if all Variant types are trivial.
template <typename... Ts>
using AllTriviallyCopyable = vtkm::ListAll<vtkm::List<Ts...>, vtkmstd::is_trivially_copyable>;

// Single argument version of is_trivially_constructible
template <typename T>
using Constructible = vtkmstd::is_trivially_constructible<T>;

template <typename... Ts>
using AllTriviallyConstructible = vtkm::ListAll<vtkm::List<Ts...>, Constructible>;

template <typename... Ts>
using AllTriviallyDestructible =
  vtkm::ListAll<vtkm::List<Ts...>, vtkmstd::is_trivially_destructible>;

// clang-format off
// --------------------------------------------------------------------------------
// Union type used inside of Variant
//
// You may be asking yourself, why not just use an std::aligned_union rather than a real union
// type? That was our first implementation, but the problem is that the std::aligned_union
// reference needs to be recast to the actual type. Typically you would do that with
// reinterpret_cast. However, doing that leads to undefined behavior. The C++ compiler assumes that
// 2 pointers of different types point to different memory (even if it is clear that they are set
// to the same address). That means optimizers can remove code because it "knows" that data in one
// type cannot affect data in another type. (See Shafik Yaghmour's excellent writeup at
// https://gist.github.com/shafik/848ae25ee209f698763cffee272a58f8 for more details.) To safely
// change the type of an std::aligned_union, you really have to do an std::memcpy. This is
// problematic for types that cannot be trivially copied. Another problem is that we found that
// device compilers do not optimize the memcpy as well as most CPU compilers. Likely, memcpy is
// used much less frequently on GPU devices.
//
// Part of the trickiness of the union implementation is preserving whether the type is
// trivially constructible and copyable. The catch is that if any member of the union is not
// trivial, then the union's default constructor is deleted. To get around that, a non-default
// constructor is added, which we can use to construct the union for non-trivial types. Working
// with types that have non-trivial destructors is particularly tricky. Again, if any member of
// the union has a non-trivial destructor, the union's destructor is deleted. Unlike a
// constructor, you cannot just say to use a different destructor. Thus, we have to define our
// own destructor for the union. Technically, the destructor here does not do anything; the
// actual destruction is handled by the Variant class that contains this VariantUnion. We
// actually need two separate implementations of our union, one that defines a destructor and
// one that uses the default destructor, because if you define your own destructor, you can
// lose the trivial constructor and trivial copy properties.
//
// TD = trivially destructible
template <typename T0, typename... Ts>
union VariantUnionTD;
// NTD = non-trivially destructible
template <typename T0, typename... Ts>
union VariantUnionNTD;

$for(param_length in range(max_expanded))\
template <$typename_list(param_length)>
union VariantUnionTD<$type_list(param_length)>
{
$for(param_index in range(param_length + 1))\
  T$(param_index) V$(param_index);
$endfor\
  VTK_M_DEVICE VariantUnionTD(vtkm::internal::NullType) { }
  VariantUnionTD() = default;
};

template <$typename_list(param_length)>
union VariantUnionNTD<$type_list(param_length)>
{
$for(param_index in range(param_length + 1))\
  T$(param_index) V$(param_index);
$endfor\
  VTK_M_DEVICE VariantUnionNTD(vtkm::internal::NullType) { }
  VariantUnionNTD() = default;
  VTK_M_DEVICE ~VariantUnionNTD() { }
};

$endfor\
template <$typename_list(max_expanded), typename... Ts>
union VariantUnionTD<$type_list(max_expanded), Ts...>
{
$for(param_index in range(max_expanded))\
  T$(param_index) V$(param_index);
$endfor\
  VariantUnionTD<T$(max_expanded), Ts...> Remaining;
  VTK_M_DEVICE VariantUnionTD(vtkm::internal::NullType) { }
  VariantUnionTD() = default;
};

template <$typename_list(max_expanded), typename... Ts>
union VariantUnionNTD<$type_list(max_expanded), Ts...>
{
$for(param_index in range(max_expanded))\
  T$(param_index) V$(param_index);
$endfor\
  VariantUnionNTD<T$(max_expanded), Ts...> Remaining;
  VTK_M_DEVICE VariantUnionNTD(vtkm::internal::NullType) { }
  VariantUnionNTD() = default;
  VTK_M_DEVICE ~VariantUnionNTD() { }
};

// clang-format on

template <bool TriviallyDestructible, typename... Ts>
struct VariantUnionFinder;

template <typename... Ts>
struct VariantUnionFinder<true, Ts...>
{
  using type = VariantUnionTD<Ts...>;
};

template <typename... Ts>
struct VariantUnionFinder<false, Ts...>
{
  using type = VariantUnionNTD<Ts...>;
};

template <typename... Ts>
using VariantUnion =
  typename VariantUnionFinder<AllTriviallyDestructible<Ts...>::value, Ts...>::type;

// --------------------------------------------------------------------------------
// Methods to get values out of the variant union
template <vtkm::IdComponent I, typename UnionType>
struct VariantUnionGetImpl;

$for(param_index in range(max_expanded))\
template <typename UnionType>
struct VariantUnionGetImpl<$(param_index), UnionType>
{
  using ReturnType = decltype(std::declval<UnionType>().V$(param_index));
  VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
  {
    return storage.V$(param_index);
  }
  VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
  {
    return storage.V$(param_index);
  }
};

$endfor\
template <vtkm::IdComponent I, typename UnionType>
struct VariantUnionGetImpl
{
  VTKM_STATIC_ASSERT(I >= $(max_expanded));
  using RecursiveGet =
    VariantUnionGetImpl<I - $(max_expanded), decltype(std::declval<UnionType&>().Remaining)>;
  using ReturnType = typename RecursiveGet::ReturnType;
  VTK_M_DEVICE static ReturnType& Get(UnionType& storage) noexcept
  {
    return RecursiveGet::Get(storage.Remaining);
  }
  VTK_M_DEVICE static const ReturnType& Get(const UnionType& storage) noexcept
  {
    return RecursiveGet::Get(storage.Remaining);
  }
};

template <vtkm::IdComponent I, typename UnionType>
VTK_M_DEVICE auto VariantUnionGet(UnionType& storage) noexcept
  -> decltype(VariantUnionGetImpl<I, typename std::decay<UnionType>::type>::Get(storage))&
{
  return VariantUnionGetImpl<I, typename std::decay<UnionType>::type>::Get(storage);
}

// --------------------------------------------------------------------------------
// Internal implementation of CastAndCall for Variant
template <std::size_t NumCases>
struct VariantCases
{
  template <typename Functor, typename UnionType, typename... Args>
  VTK_M_DEVICE static inline auto CastAndCall(
    vtkm::IdComponent index,
    Functor&& f,
    UnionType& storage,
    Args&&... args) noexcept(noexcept(f(storage.V0, args...)))
    -> decltype(f(storage.V0, args...))
  {
    VTKM_ASSERT((index >= 0) && (index < static_cast<vtkm::IdComponent>(NumCases)));
    switch (index)
    {
$for(param_index in range(max_expanded))\
      case $(param_index):
        // If you get a compile error here, it probably means that you have called
        // Variant::CastAndCall with a functor that does not accept one of the types in the
        // Variant. The functor you provide must be callable with all types in the Variant, not
        // just the one that it currently holds.
        return f(storage.V$(param_index), std::forward<Args>(args)...);
$endfor\
      default:
        // The remaining types are stored in storage.Remaining; recurse with a shifted index.
        return VariantCases<NumCases - $(max_expanded)>::CastAndCall(
          index - $(max_expanded), std::forward<Functor>(f), storage.Remaining, std::forward<Args>(args)...);
    }
  }
};

template<>
struct VariantCases<1>
{
  template <typename Functor, typename UnionType, typename... Args>
  VTK_M_DEVICE static inline auto CastAndCall(
    vtkm::IdComponent index,
    Functor&& f,
    UnionType& storage,
    Args&&... args) noexcept(noexcept(f(storage.V0, args...)))
    -> decltype(f(storage.V0, args...))
  {
    // Assume index is 0. Saves us some conditionals.
    VTKM_ASSERT(index == 0);
    (void)index;
    return f(storage.V0, std::forward<Args>(args)...);
  }
};

$for(case_index in range(2, max_expanded + 1))\
template<>
struct VariantCases<$(case_index)>
{
  template <typename Functor, typename UnionType, typename... Args>
  VTK_M_DEVICE static inline auto CastAndCall(
    vtkm::IdComponent index,
    Functor&& f,
    UnionType& storage,
    Args&&... args) noexcept(noexcept(f(storage.V0, args...)))
    -> decltype(f(storage.V0, args...))
  {
    // An out-of-range index is a bug caught by the assert below in debug builds, so the default
    // label shares the body of case 0 instead of needing a code path of its own. This can save
    // the compiler some conditionals.
    VTKM_ASSERT((index >= 0) && (index < $(case_index)));
    switch (index)
    {
      default:
$for(param_index in range(case_index))\
      case $(param_index):
        // If you get a compile error here, it probably means that you have called
        // Variant::CastAndCall with a functor that does not accept one of the types in the
        // Variant. The functor you provide must be callable with all types in the Variant, not
        // just the one that it currently holds.
        return f(storage.V$(param_index), std::forward<Args>(args)...);
$endfor\
    }
  }
};

$endfor
template <std::size_t UnionSize, typename Functor, typename UnionType, typename... Args>
VTK_M_DEVICE inline auto VariantCastAndCallImpl(
  vtkm::IdComponent index,
  Functor&& f,
  UnionType& storage,
  Args&&... args) noexcept(noexcept(f(storage.V0, args...)))
  -> decltype(f(storage.V0, args...))
{
  return VariantCases<UnionSize>::CastAndCall(
    index, std::forward<Functor>(f), storage, std::forward<Args>(args)...);
}

}
}
}
} // vtkm::VTK_M_NAMESPACE::internal::detail
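The comment above `VariantUnionTD` and `VariantUnionNTD` makes two claims:

- A union keeps its trivial properties only while it does not declare its own destructor.
- A non-trivial member must be destroyed by the owning class.

Below is a minimal sketch that checks this with `static_assert`. It uses hypothetical stand-ins: `NullTag` for `vtkm::internal::NullType`, `TrivialU` and `NonTrivialU` for the two union flavors, and `StringOrInt` for the owning `Variant`.

```cpp
#include <cassert>
#include <new>
#include <string>
#include <type_traits>

// Hypothetical stand-in for vtkm::internal::NullType.
struct NullTag
{
};

// Trivially destructible flavor (like VariantUnionTD): no user-provided
// destructor, so the union stays trivial despite its non-default constructor.
union TrivialU
{
  int V0;
  float V1;
  TrivialU(NullTag) {}
  TrivialU() = default;
};

static_assert(std::is_trivially_copyable<TrivialU>::value, "TD union is trivially copyable");
static_assert(std::is_trivially_default_constructible<TrivialU>::value,
              "TD union is trivially default constructible");
static_assert(std::is_trivially_destructible<TrivialU>::value, "TD union is trivially destructible");

// Non-trivially destructible flavor (like VariantUnionNTD): std::string makes
// the implicit destructor deleted, so an empty one is declared, and the union
// is created through the NullTag constructor.
union NonTrivialU
{
  int V0;
  std::string V1;
  NonTrivialU(NullTag) {}
  ~NonTrivialU() {}
};

static_assert(!std::is_trivially_destructible<NonTrivialU>::value,
              "NTD union has a user-provided destructor");

// Hypothetical owner playing the role of Variant: it records which member is
// active and constructs and destroys that member explicitly.
struct StringOrInt
{
  NonTrivialU Storage{ NullTag{} };
  int Index = 0;

  explicit StringOrInt(const std::string& s)
  {
    new (&Storage.V1) std::string(s);
    Index = 1;
  }

  ~StringOrInt()
  {
    if (Index == 1)
    {
      Storage.V1.~basic_string(); // the union's own destructor does nothing
    }
  }

  StringOrInt(const StringOrInt&) = delete;
  StringOrInt& operator=(const StringOrInt&) = delete;
};
```

Splitting the union into two templates means types that are all trivially destructible never pay for the empty user-provided destructor. That destructor would otherwise make the union non-trivially copyable.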