Add ability to specialize worklet for device

This adds an ExecutionSignature tag named Device that passes the
DeviceAdapterTag as an argument to the worklet's operator(). This allows
worklets to specialize their code based on the device.
This commit is contained in:
Kenneth Moreland 2019-03-07 18:32:56 -07:00
parent e787d52afc
commit 4e34feecb4
13 changed files with 1538 additions and 2343 deletions

@ -0,0 +1,147 @@
# Add ability to specialize a worklet for a device
This change adds an execution signature tag named `Device` that passes
a `DeviceAdapterTag` to the worklet's parenthesis operator. This allows the
worklet to specialize its operation. This features is available in all
worklets.
The following example shows a worklet that specializes itself for the CUDA
device.
```cpp
struct DeviceSpecificWorklet : vtkm::worklet::WorkletMapField
{
using ControlSignature = void(FieldIn, FieldOut);
using ExecutionSignature = _2(_1, Device);
// Specialization for the Cuda device.
template <typename T>
T operator()(T x, vtkm::cont::DeviceAdapterTagCuda) const
{
// Special cuda implementation
}
// General implementation
template <typename T, typename Device>
T operator()(T x, Device) const
{
// General implementation
}
};
```
## Effect on compile time and binary size
This change necessitated adding a template parameter for the device that
followed at least from the schedule all the way down. This has the
potential for duplicating several of the support methods (like
`DoWorkletInvokeFunctor`) that would otherwise have the same type. This is
especially true between the devices that run on the CPU as they should all
be sharing the same portals from `ArrayHandle`s. So the question is whether
it causes compile to take longer or cause a significant increase in
binaries.
To informally test, I first ran a clean debug compile on my Windows machine
with the serial and tbb devices. The build itself took **3 minutes, 50
seconds**. Here is a list of the binary sizes in the bin directory:
```
kmorel2 0> du -sh *.exe *.dll
200K BenchmarkArrayTransfer_SERIAL.exe
204K BenchmarkArrayTransfer_TBB.exe
424K BenchmarkAtomicArray_SERIAL.exe
424K BenchmarkAtomicArray_TBB.exe
440K BenchmarkCopySpeeds_SERIAL.exe
580K BenchmarkCopySpeeds_TBB.exe
4.1M BenchmarkDeviceAdapter_SERIAL.exe
5.3M BenchmarkDeviceAdapter_TBB.exe
7.9M BenchmarkFieldAlgorithms_SERIAL.exe
7.9M BenchmarkFieldAlgorithms_TBB.exe
22M BenchmarkFilters_SERIAL.exe
22M BenchmarkFilters_TBB.exe
276K BenchmarkRayTracing_SERIAL.exe
276K BenchmarkRayTracing_TBB.exe
4.4M BenchmarkTopologyAlgorithms_SERIAL.exe
4.4M BenchmarkTopologyAlgorithms_TBB.exe
712K Rendering_SERIAL.exe
712K Rendering_TBB.exe
708K UnitTests_vtkm_cont_arg_testing.exe
1.7M UnitTests_vtkm_cont_internal_testing.exe
13M UnitTests_vtkm_cont_serial_testing.exe
14M UnitTests_vtkm_cont_tbb_testing.exe
18M UnitTests_vtkm_cont_testing.exe
13M UnitTests_vtkm_cont_testing_mpi.exe
736K UnitTests_vtkm_exec_arg_testing.exe
136K UnitTests_vtkm_exec_internal_testing.exe
196K UnitTests_vtkm_exec_serial_internal_testing.exe
196K UnitTests_vtkm_exec_tbb_internal_testing.exe
2.0M UnitTests_vtkm_exec_testing.exe
83M UnitTests_vtkm_filter_testing.exe
476K UnitTests_vtkm_internal_testing.exe
148K UnitTests_vtkm_interop_internal_testing.exe
1.3M UnitTests_vtkm_interop_testing.exe
2.9M UnitTests_vtkm_io_reader_testing.exe
548K UnitTests_vtkm_io_writer_testing.exe
792K UnitTests_vtkm_rendering_testing.exe
3.7M UnitTests_vtkm_testing.exe
320K UnitTests_vtkm_worklet_internal_testing.exe
65M UnitTests_vtkm_worklet_testing.exe
11M vtkm_cont-1.3.dll
2.1M vtkm_interop-1.3.dll
21M vtkm_rendering-1.3.dll
3.9M vtkm_worklet-1.3.dll
```
After making the singular change to the `Invocation` object to add the
`DeviceAdapterTag` as a template parameter (which should cause any extra
compile instances) the compile took **4 minuts and 5 seconds**. Here is the
new list of binaries.
```
kmorel2 0> du -sh *.exe *.dll
200K BenchmarkArrayTransfer_SERIAL.exe
204K BenchmarkArrayTransfer_TBB.exe
424K BenchmarkAtomicArray_SERIAL.exe
424K BenchmarkAtomicArray_TBB.exe
440K BenchmarkCopySpeeds_SERIAL.exe
580K BenchmarkCopySpeeds_TBB.exe
4.1M BenchmarkDeviceAdapter_SERIAL.exe
5.3M BenchmarkDeviceAdapter_TBB.exe
7.9M BenchmarkFieldAlgorithms_SERIAL.exe
7.9M BenchmarkFieldAlgorithms_TBB.exe
22M BenchmarkFilters_SERIAL.exe
22M BenchmarkFilters_TBB.exe
276K BenchmarkRayTracing_SERIAL.exe
276K BenchmarkRayTracing_TBB.exe
4.4M BenchmarkTopologyAlgorithms_SERIAL.exe
4.4M BenchmarkTopologyAlgorithms_TBB.exe
712K Rendering_SERIAL.exe
712K Rendering_TBB.exe
708K UnitTests_vtkm_cont_arg_testing.exe
1.7M UnitTests_vtkm_cont_internal_testing.exe
13M UnitTests_vtkm_cont_serial_testing.exe
14M UnitTests_vtkm_cont_tbb_testing.exe
19M UnitTests_vtkm_cont_testing.exe
13M UnitTests_vtkm_cont_testing_mpi.exe
736K UnitTests_vtkm_exec_arg_testing.exe
136K UnitTests_vtkm_exec_internal_testing.exe
196K UnitTests_vtkm_exec_serial_internal_testing.exe
196K UnitTests_vtkm_exec_tbb_internal_testing.exe
2.0M UnitTests_vtkm_exec_testing.exe
86M UnitTests_vtkm_filter_testing.exe
476K UnitTests_vtkm_internal_testing.exe
148K UnitTests_vtkm_interop_internal_testing.exe
1.3M UnitTests_vtkm_interop_testing.exe
2.9M UnitTests_vtkm_io_reader_testing.exe
548K UnitTests_vtkm_io_writer_testing.exe
792K UnitTests_vtkm_rendering_testing.exe
3.7M UnitTests_vtkm_testing.exe
320K UnitTests_vtkm_worklet_internal_testing.exe
68M UnitTests_vtkm_worklet_testing.exe
11M vtkm_cont-1.3.dll
2.1M vtkm_interop-1.3.dll
21M vtkm_rendering-1.3.dll
3.9M vtkm_worklet-1.3.dll
```
So far the increase is quite negligible.

@ -737,8 +737,14 @@ struct Serialization<vtkm::cont::ArrayHandle<T>>
} // diy
#ifndef vtk_m_cont_ArrayHandle_hxx
#include <vtkm/cont/ArrayHandle.hxx>
#endif
#ifndef vtk_m_cont_internal_ArrayHandleBasicImpl_h
#include <vtkm/cont/internal/ArrayHandleBasicImpl.h>
#endif
#include <vtkm/cont/internal/ArrayExportMacros.h>
#ifndef vtkm_cont_ArrayHandle_cxx

@ -17,6 +17,10 @@
// Laboratory (LANL), the U.S. Government retains certain rights in
// this software.
//============================================================================
#ifndef vtk_m_cont_ArrayHandle_hxx
#define vtk_m_cont_ArrayHandle_hxx
#include <vtkm/cont/ArrayHandle.h>
namespace vtkm
{
@ -425,3 +429,5 @@ VTKM_CONT void Serialization<vtkm::cont::ArrayHandle<T>>::load(BinaryBuffer& bb,
}
}
} // diy
#endif //vtk_m_cont_ArrayHandle_hxx

@ -20,6 +20,7 @@
#define vtkm_cont_internal_ArrayHandleImpl_cxx
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/internal/ArrayHandleBasicImpl.h>
namespace vtkm

@ -18,11 +18,11 @@
// this software.
//============================================================================
#include <vtkm/cont/ArrayHandle.h>
#ifndef vtk_m_cont_internal_ArrayHandleBasicImpl_h
#define vtk_m_cont_internal_ArrayHandleBasicImpl_h
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/StorageBasic.h>
#include <type_traits>
@ -57,19 +57,21 @@ struct ExecutionPortalFactoryBasic
#ifndef VTKM_DOXYGEN_ONLY
;
#else // VTKM_DOXYGEN_ONLY
{
/// The portal type.
using PortalType = SomePortalType;
/// The cont portal type.
using ConstPortalType = SomePortalType;
/// The cont portal type.
using ConstPortalType = SomePortalType;
/// Create a portal to access the execution data from @a start to @a end.
VTKM_CONT
static PortalType CreatePortal(ValueType* start, ValueType* end);
/// Create a portal to access the execution data from @a start to @a end.
VTKM_CONT
static PortalType CreatePortal(ValueType* start, ValueType* end);
/// Create a const portal to access the execution data from @a start to @a end.
VTKM_CONT
static PortalConstType CreatePortalConst(const ValueType* start, const ValueType* end);
/// Create a const portal to access the execution data from @a start to @a end.
VTKM_CONT
static PortalConstType CreatePortalConst(const ValueType* start, const ValueType* end);
};
#endif // VTKM_DOXYGEN_ONLY
/// Typeless interface for interacting with a execution memory buffer when using basic storage.
@ -233,6 +235,7 @@ public:
template <typename DeviceTag>
struct ExecutionTypes
{
VTKM_IS_DEVICE_ADAPTER_TAG(DeviceTag);
using Portal = typename PortalFactory<DeviceTag>::PortalType;
using PortalConst = typename PortalFactory<DeviceTag>::PortalConstType;
};
@ -302,6 +305,8 @@ extern template class VTKM_CONT_TEMPLATE_EXPORT
#endif
#endif
#ifndef vtk_m_cont_internal_ArrayHandleBasicImpl_hxx
#include <vtkm/cont/internal/ArrayHandleBasicImpl.hxx>
#endif
#endif // vtk_m_cont_internal_ArrayHandleBasicImpl_h

File diff suppressed because it is too large Load Diff

@ -51,6 +51,7 @@ $# Ignore the following comment. It is meant for the generated file.
#include <vtkm/internal/Invocation.h>
#include <vtkm/exec/arg/Fetch.h>
#include <vtkm/exec/arg/FetchTagExecObject.h>
$# This needs to match the max_parameters in FunctionInterfaceDetailPre.h.in
$py(max_parameters=20)\
@ -115,6 +116,11 @@ namespace internal
namespace detail
{
struct DummyDeviceControlSignatureTag
{
using FetchTag = vtkm::exec::arg::FetchTagExecObject;
};
/// A helper class that takes an \c Invocation object and an index to a
/// parameter in the ExecutionSignature and finds the \c Fetch type valid for
/// that parameter.
@ -133,15 +139,38 @@ struct InvocationToFetch
using AspectTag = typename ExecutionSignatureTag::AspectTag;
// Find the fetch tag from the control signature tag pointed to by
// ParameterIndex.
// ParameterIndex. Note that ControlParameterIndex of 0 is reserved
// for getting the device adapter tag.
using ControlInterface = typename Invocation::ControlInterface;
using ControlSignatureTag = typename ControlInterface::template ParameterType<ControlParameterIndex>::type;
using ControlSignatureTag =
typename std::conditional<
ControlParameterIndex == 0,
DummyDeviceControlSignatureTag,
typename ControlInterface::template ParameterType<ControlParameterIndex>::type>::type;
using FetchTag = typename ControlSignatureTag::FetchTag;
using ExecObjectType =
typename Invocation::ParameterInterface::template ParameterType<ControlParameterIndex>::type;
typename std::conditional<
ControlParameterIndex == 0,
typename Invocation::DeviceAdapterTag,
typename Invocation::ParameterInterface::template ParameterType<ControlParameterIndex>::type>::type;
using type = vtkm::exec::arg::Fetch<FetchTag, AspectTag, ThreadIndicesType, ExecObjectType>;
VTKM_EXEC static ExecObjectType GetParameterImpl(const Invocation&, std::true_type)
{
return typename Invocation::DeviceAdapterTag();
}
VTKM_EXEC static ExecObjectType GetParameterImpl(const Invocation& invocation, std::false_type)
{
return invocation.Parameters.template GetParameter<ControlParameterIndex>();
}
VTKM_EXEC static ExecObjectType GetParameter(const Invocation& invocation)
{
return GetParameterImpl(invocation, std::integral_constant<bool, ControlParameterIndex == 0>());
}
};
// clang-format off
@ -154,6 +183,7 @@ template <typename WorkletType,
typename OutputToInputMapType,
typename VisitArrayType,
typename ThreadToOutputMapType,
typename DeviceAdapterTag,
typename ThreadIndicesType,
$template_params(num_params)>
VTKM_EXEC void DoWorkletInvokeFunctor(
@ -163,7 +193,9 @@ VTKM_EXEC void DoWorkletInvokeFunctor(
vtkm::internal::FunctionInterface<$signature(num_params)>,
InputDomainIndex,
OutputToInputMapType,
VisitArrayType, ThreadToOutputMapType>& invocation,
VisitArrayType,
ThreadToOutputMapType,
DeviceAdapterTag>& invocation,
const ThreadIndicesType& threadIndices)
{
using Invocation = vtkm::internal::Invocation<ParameterInterface,
@ -171,15 +203,16 @@ VTKM_EXEC void DoWorkletInvokeFunctor(
vtkm::internal::FunctionInterface<$signature(num_params)>,
InputDomainIndex,
OutputToInputMapType,
VisitArrayType>;
VisitArrayType,
ThreadToOutputMapType,
DeviceAdapterTag>;
$for(param_index in range(1, num_params+1))\
using FetchInfo$(param_index) = InvocationToFetch<ThreadIndicesType, Invocation, $(param_index)>;
using FetchType$(param_index) = typename FetchInfo$(param_index)::type;
FetchType$(param_index) fetch$(param_index);
typename FetchType$(param_index)::ValueType $pname(param_index) =
fetch$(param_index).Load(threadIndices,
invocation.Parameters.template GetParameter<FetchInfo$(param_index)::ControlParameterIndex>());
fetch$(param_index).Load(threadIndices, FetchInfo$(param_index)::GetParameter(invocation));
$endfor\
using FetchInfo0 = InvocationToFetch<ThreadIndicesType, Invocation, 0>;
@ -197,15 +230,10 @@ $endfor\
auto $pname(0) = typename ReturnFetchType::ValueType(worklet($arg_list(num_params)));
$for(param_index in range(1, num_params+1))\
fetch$(param_index).Store(threadIndices,
invocation.Parameters.template GetParameter<FetchInfo$(param_index)::ControlParameterIndex>(),
$pname(param_index));
fetch$(param_index).Store(threadIndices, FetchInfo$(param_index)::GetParameter(invocation), $pname(param_index));
$endfor\
returnFetch.Store(
threadIndices,
invocation.Parameters.template GetParameter<FetchInfo0::ControlParameterIndex>(),
$pname(0));
returnFetch.Store(threadIndices, FetchInfo0::GetParameter(invocation), $pname(0));
}
template <typename WorkletType,
@ -213,7 +241,9 @@ template <typename WorkletType,
typename ControlInterface,
vtkm::IdComponent InputDomainIndex,
typename OutputToInputMapType,
typename VisitArrayType, typename ThreadToOutputMapType,
typename VisitArrayType,
typename ThreadToOutputMapType,
typename DeviceAdapterTag,
typename ThreadIndicesType,
$template_params(num_params, start=1)>
VTKM_EXEC void DoWorkletInvokeFunctor(
@ -223,7 +253,9 @@ VTKM_EXEC void DoWorkletInvokeFunctor(
vtkm::internal::FunctionInterface<$signature(num_params, return_type='void')>,
InputDomainIndex,
OutputToInputMapType,
VisitArrayType, ThreadToOutputMapType>& invocation,
VisitArrayType,
ThreadToOutputMapType,
DeviceAdapterTag>& invocation,
const ThreadIndicesType& threadIndices)
{
using Invocation =
@ -232,15 +264,16 @@ VTKM_EXEC void DoWorkletInvokeFunctor(
vtkm::internal::FunctionInterface<$signature(num_params, return_type='void')>,
InputDomainIndex,
OutputToInputMapType,
VisitArrayType>;
VisitArrayType,
ThreadToOutputMapType,
DeviceAdapterTag>;
$for(param_index in range(1, num_params+1))\
using FetchInfo$(param_index) = InvocationToFetch<ThreadIndicesType, Invocation, $(param_index)>;
using FetchType$(param_index) = typename FetchInfo$(param_index)::type;
FetchType$(param_index) fetch$(param_index);
typename FetchType$(param_index)::ValueType $pname(param_index) =
fetch$(param_index).Load(threadIndices,
invocation.Parameters.template GetParameter<FetchInfo$(param_index)::ControlParameterIndex>());
fetch$(param_index).Load(threadIndices, FetchInfo$(param_index)::GetParameter(invocation));
$endfor\
// If you got a compile error on the following line, it probably means that
@ -255,9 +288,7 @@ $endfor\
worklet($arg_list(num_params));
$for(param_index in range(1, num_params+1))\
fetch$(param_index).Store(threadIndices,
invocation.Parameters.template GetParameter<FetchInfo$(param_index)::ControlParameterIndex>(),
$pname(param_index));
fetch$(param_index).Store(threadIndices, FetchInfo$(param_index)::GetParameter(invocation), $pname(param_index));
$endfor\
}

@ -39,7 +39,8 @@ template <typename ParameterInterface_,
vtkm::IdComponent InputDomainIndex_,
typename OutputToInputMapType_ = vtkm::internal::NullType,
typename VisitArrayType_ = vtkm::internal::NullType,
typename ThreadToOutputMapType_ = vtkm::internal::NullType>
typename ThreadToOutputMapType_ = vtkm::internal::NullType,
typename DeviceAdapterTag_ = vtkm::internal::NullType>
struct Invocation
{
/// \brief The types of the parameters
@ -99,6 +100,13 @@ struct Invocation
///
using ThreadToOutputMapType = ThreadToOutputMapType_;
/// \brief The tag for the device adapter on which the worklet is run.
///
/// When the worklet is dispatched on a particular device, this type in the
/// Invocation is set to the tag associated with that device.
///
using DeviceAdapterTag = DeviceAdapterTag_;
/// \brief Default Invocation constructors that holds the given parameters
/// by reference.
VTKM_CONT
@ -125,7 +133,8 @@ struct Invocation
InputDomainIndex,
OutputToInputMapType,
VisitArrayType,
ThreadToOutputMapType>;
ThreadToOutputMapType,
DeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
@ -151,7 +160,8 @@ struct Invocation
InputDomainIndex,
OutputToInputMapType,
VisitArrayType,
ThreadToOutputMapType>;
ThreadToOutputMapType,
DeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
@ -177,7 +187,8 @@ struct Invocation
InputDomainIndex,
OutputToInputMapType,
VisitArrayType,
ThreadToOutputMapType>;
ThreadToOutputMapType,
DeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
@ -203,7 +214,8 @@ struct Invocation
NewInputDomainIndex,
OutputToInputMapType,
VisitArrayType,
ThreadToOutputMapType>;
ThreadToOutputMapType,
DeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
@ -229,7 +241,8 @@ struct Invocation
InputDomainIndex,
NewOutputToInputMapType,
VisitArrayType,
ThreadToOutputMapType>;
ThreadToOutputMapType,
DeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
@ -255,7 +268,8 @@ struct Invocation
InputDomainIndex,
OutputToInputMapType,
NewVisitArrayType,
ThreadToOutputMapType>;
ThreadToOutputMapType,
DeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
@ -281,7 +295,8 @@ struct Invocation
InputDomainIndex,
OutputToInputMapType,
VisitArrayType,
NewThreadToOutputMapType>;
NewThreadToOutputMapType,
DeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
@ -295,6 +310,33 @@ struct Invocation
this->Parameters, this->OutputToInputMap, this->VisitArray, newThreadToOutputMap);
}
/// Defines a new \c Invocation type that is the same as this type except
/// with the \c DeviceAdapterTag replaced.
///
template <typename NewDeviceAdapterTag>
struct ChangeDeviceAdapterTagType
{
using type = Invocation<ParameterInterface,
ControlInterface,
ExecutionInterface,
InputDomainIndex,
OutputToInputMapType,
VisitArrayType,
ThreadToOutputMapType,
NewDeviceAdapterTag>;
};
/// Returns a new \c Invocation that is the same as this one except that the
/// \c DeviceAdapterTag is replaced with that provided.
///
template <typename NewDeviceAdapterTag>
VTKM_CONT typename ChangeDeviceAdapterTagType<NewDeviceAdapterTag>::type ChangeDeviceAdapterTag(
NewDeviceAdapterTag) const
{
return typename ChangeDeviceAdapterTagType<NewDeviceAdapterTag>::type(
this->Parameters, this->OutputToInputMap, this->VisitArray, this->ThreadToOutputMap);
}
/// A convenience alias for the input domain type.
///
using InputDomainType =
@ -333,7 +375,8 @@ private:
InputDomainIndex,
OutputToInputMapType,
VisitArrayType,
ThreadToOutputMapType>&) = delete;
ThreadToOutputMapType,
DeviceAdapterTag>&) = delete;
};
/// Convenience function for creating an Invocation object.

@ -128,14 +128,14 @@ public:
}
}; //class createLeafs
template <typename Device>
template <typename DeviceAdapterTag>
class LinearBVHBuilder::GatherVecCast : public vtkm::worklet::WorkletMapField
{
private:
using Vec4IdArrayHandle = typename vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Id, 4>>;
using Vec4IntArrayHandle = typename vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Int32, 4>>;
using PortalConst = typename Vec4IdArrayHandle::ExecutionTypes<Device>::PortalConst;
using Portal = typename Vec4IntArrayHandle::ExecutionTypes<Device>::Portal;
using PortalConst = typename Vec4IdArrayHandle::ExecutionTypes<DeviceAdapterTag>::PortalConst;
using Portal = typename Vec4IntArrayHandle::ExecutionTypes<DeviceAdapterTag>::Portal;
private:
PortalConst InputPortal;
@ -146,9 +146,9 @@ public:
GatherVecCast(const Vec4IdArrayHandle& inputPortal,
Vec4IntArrayHandle& outputPortal,
const vtkm::Id& size)
: InputPortal(inputPortal.PrepareForInput(Device()))
: InputPortal(inputPortal.PrepareForInput(DeviceAdapterTag()))
{
this->OutputPortal = outputPortal.PrepareForOutput(size, Device());
this->OutputPortal = outputPortal.PrepareForOutput(size, DeviceAdapterTag());
}
using ControlSignature = void(FieldIn);
using ExecutionSignature = void(WorkIndex, _1);

@ -307,12 +307,12 @@ public:
} //namespace
template <typename Device, typename LocatorType>
template <typename DeviceAdapterTag, typename LocatorType>
class Sampler : public vtkm::worklet::WorkletMapField
{
private:
using ColorArrayHandle = typename vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Float32, 4>>;
using ColorArrayPortal = typename ColorArrayHandle::ExecutionTypes<Device>::PortalConst;
using ColorArrayPortal = typename ColorArrayHandle::ExecutionTypes<DeviceAdapterTag>::PortalConst;
ColorArrayPortal ColorMap;
vtkm::Id ColorMapSize;
vtkm::Float32 MinScalar;
@ -327,7 +327,7 @@ public:
const vtkm::Float32& maxScalar,
const vtkm::Float32& sampleDistance,
const LocatorType& locator)
: ColorMap(colorMap.PrepareForInput(Device()))
: ColorMap(colorMap.PrepareForInput(DeviceAdapterTag()))
, MinScalar(minScalar)
, SampleDistance(sampleDistance)
, InverseDeltaScalar(minScalar)
@ -500,12 +500,12 @@ public:
}
}; //Sampler
template <typename Device, typename LocatorType>
template <typename DeviceAdapterTag, typename LocatorType>
class SamplerCellAssoc : public vtkm::worklet::WorkletMapField
{
private:
using ColorArrayHandle = typename vtkm::cont::ArrayHandle<vtkm::Vec<vtkm::Float32, 4>>;
using ColorArrayPortal = typename ColorArrayHandle::ExecutionTypes<Device>::PortalConst;
using ColorArrayPortal = typename ColorArrayHandle::ExecutionTypes<DeviceAdapterTag>::PortalConst;
ColorArrayPortal ColorMap;
vtkm::Id ColorMapSize;
vtkm::Float32 MinScalar;
@ -520,7 +520,7 @@ public:
const vtkm::Float32& maxScalar,
const vtkm::Float32& sampleDistance,
const LocatorType& locator)
: ColorMap(colorMap.PrepareForInput(Device()))
: ColorMap(colorMap.PrepareForInput(DeviceAdapterTag()))
, MinScalar(minScalar)
, SampleDistance(sampleDistance)
, InverseDeltaScalar(minScalar)

@ -832,11 +832,15 @@ private:
}
template <typename Invocation, typename RangeType, typename DeviceAdapter>
VTKM_CONT void InvokeSchedule(const Invocation& invocation, RangeType range, DeviceAdapter) const
VTKM_CONT void InvokeSchedule(const Invocation& invocation,
RangeType range,
DeviceAdapter device) const
{
using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
using TaskTypes = typename vtkm::cont::DeviceTaskTypes<DeviceAdapter>;
auto invocationForDevice = invocation.ChangeDeviceAdapterTag(device);
// The TaskType class handles the magic of fetching values
// for each instance and calling the worklet's function.
// The TaskType will evaluate to one of the following classes:
@ -844,7 +848,7 @@ private:
// vtkm::exec::internal::TaskSingular
// vtkm::exec::internal::TaskTiling1D
// vtkm::exec::internal::TaskTiling3D
auto task = TaskTypes::MakeTask(this->Worklet, invocation, range);
auto task = TaskTypes::MakeTask(this->Worklet, invocationForDevice, range);
Algorithm::ScheduleTask(task, range);
}
};

@ -113,6 +113,15 @@ public:
///
using VisitIndex = vtkm::exec::arg::VisitIndex;
/// \c ExecutionSignature tag for getting the device adapter tag.
///
struct Device : vtkm::exec::arg::ExecutionSignatureTagBase
{
// INDEX 0 (which is an invalid parameter index) is reserved to mean the device adapter tag.
static constexpr vtkm::IdComponent INDEX = 0;
using AspectTag = vtkm::exec::arg::AspectTagDefault;
};
/// \c ControlSignature tag for execution object inputs.
struct ExecObject : vtkm::cont::arg::ControlSignatureTagBase
{

@ -19,6 +19,7 @@
//============================================================================
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/cont/ExecutionObjectBase.h>
#include <vtkm/cont/VariantArrayHandle.h>
#include <vtkm/cont/internal/DeviceAdapterTag.h>
@ -27,21 +28,34 @@
#include <vtkm/cont/testing/Testing.h>
struct SimpleExecObject : vtkm::cont::ExecutionObjectBase
{
template <typename Device>
Device PrepareForExecution(Device) const
{
return Device();
}
};
struct TestExecObjectWorklet
{
template <typename T>
class Worklet : public vtkm::worklet::WorkletMapField
{
public:
using ControlSignature = void(FieldIn, WholeArrayIn, WholeArrayOut, FieldOut);
using ExecutionSignature = void(_1, _2, _3, _4);
using ControlSignature = void(FieldIn, WholeArrayIn, WholeArrayOut, FieldOut, ExecObject);
using ExecutionSignature = void(_1, _2, _3, _4, _5, Device);
template <typename InPortalType, typename OutPortalType>
template <typename InPortalType, typename OutPortalType, typename DeviceTag>
VTKM_EXEC void operator()(const vtkm::Id& index,
const InPortalType& execIn,
OutPortalType& execOut,
T& out) const
T& out,
DeviceTag,
DeviceTag) const
{
VTKM_IS_DEVICE_ADAPTER_TAG(DeviceTag);
if (!test_equal(execIn.Get(index), TestValue(index, T()) + T(100)))
{
this->RaiseError("Got wrong input value.");
@ -79,7 +93,7 @@ struct DoTestWorklet
std::cout << "Create and run dispatcher." << std::endl;
vtkm::worklet::DispatcherMapField<typename WorkletType::template Worklet<T>> dispatcher;
dispatcher.Invoke(counting, inputHandle, outputHandle, outputFieldArray);
dispatcher.Invoke(counting, inputHandle, outputHandle, outputFieldArray, SimpleExecObject());
std::cout << "Check result." << std::endl;
CheckPortal(outputHandle.GetPortalConstControl());
@ -92,7 +106,7 @@ struct DoTestWorklet
outputHandle.Allocate(ARRAY_SIZE);
vtkm::cont::VariantArrayHandleBase<vtkm::ListTagBase<T>> outputFieldDynamic(outputFieldArray);
dispatcher.Invoke(counting, inputHandle, outputHandle, outputFieldDynamic);
dispatcher.Invoke(counting, inputHandle, outputHandle, outputFieldDynamic, SimpleExecObject());
std::cout << "Check dynamic array result." << std::endl;
CheckPortal(outputHandle.GetPortalConstControl());