Add input range to arguments of transport

Previously the arguments to the operator of a vtkm::cont::arg::Transport
were the control object, the input domain object, and the output range.
If you wanted to, for example, check the size of an input array to make
sure it matched the input range, you would have to know the meaning of
the input domain object to query its range. This made it hard to create
generic transports, like TransportTagArrayIn, that accept data from
multiple different input domains but need to know the input range.
This commit is contained in:
Kenneth Moreland 2017-03-24 14:51:46 -06:00
parent 2481dd6248
commit dc192b793d
20 changed files with 93 additions and 38 deletions

@ -48,13 +48,11 @@ struct Transport<vtkm::cont::arg::TransportTagArrayIn, ContObjectType, Device>
template<typename InputDomainType>
VTKM_CONT
ExecObjectType operator()(const ContObjectType &object,
const InputDomainType &inputDomain,
vtkm::Id) const
const InputDomainType &vtkmNotUsed(inputDomain),
vtkm::Id inputRange,
vtkm::Id vtkmNotUsed(outputRange)) const
{
// This transport expects the input domain to be an array handle.
VTKM_IS_ARRAY_HANDLE(InputDomainType);
if (object.GetNumberOfValues() != inputDomain.GetNumberOfValues())
if (object.GetNumberOfValues() != inputRange)
{
throw vtkm::cont::ErrorBadValue(
"Input array to worklet invocation the wrong size.");

@ -51,13 +51,14 @@ struct Transport<vtkm::cont::arg::TransportTagArrayInOut, ContObjectType, Device
template<typename InputDomainType>
VTKM_CONT
ExecObjectType operator()(ContObjectType object,
const InputDomainType &,
vtkm::Id size) const
const InputDomainType &vtkmNotUsed(inputDomain),
vtkm::Id vtkmNotUsed(inputRange),
vtkm::Id outputRange) const
{
if (object.GetNumberOfValues() != size)
if (object.GetNumberOfValues() != outputRange)
{
throw vtkm::cont::ErrorBadValue(
"Input array to worklet invocation the wrong size.");
"Input/output array to worklet invocation the wrong size.");
}
return object.PrepareForInPlace(Device());

@ -50,10 +50,11 @@ struct Transport<vtkm::cont::arg::TransportTagArrayOut, ContObjectType, Device>
template<typename InputDomainType>
VTKM_CONT
ExecObjectType operator()(ContObjectType object,
const InputDomainType &,
vtkm::Id size) const
const InputDomainType &vtkmNotUsed(inputDomain),
vtkm::Id vtkmNotUsed(inputRange),
vtkm::Id outputRange) const
{
return object.PrepareForOutput(size, Device());
return object.PrepareForOutput(outputRange, Device());
}
};

@ -56,6 +56,7 @@ struct Transport<
ExecObjectType operator()(
vtkm::cont::ArrayHandle<T, vtkm::cont::StorageTagBasic> array,
const InputDomainType &,
vtkm::Id,
vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed

@ -55,6 +55,7 @@ struct Transport<vtkm::cont::arg::TransportTagCellSetIn<FromTopology,ToTopology>
VTKM_CONT
ExecObjectType operator()(const ContObjectType &object,
const InputDomainType &,
vtkm::Id,
vtkm::Id) const
{
return object.PrepareForInput(Device(),

@ -54,6 +54,7 @@ struct Transport<vtkm::cont::arg::TransportTagExecObject,ContObjectType,Device>
VTKM_CONT
ExecObjectType operator()(const ContObjectType &object,
const InputDomainType &,
vtkm::Id,
vtkm::Id) const
{
return object;

@ -88,6 +88,7 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(const ContObjectType &object,
const vtkm::cont::CellSet &inputDomain,
vtkm::Id,
vtkm::Id) const
{
if (object.GetNumberOfValues() !=

@ -60,6 +60,7 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(ContObjectType array,
const InputDomainType &,
vtkm::Id,
vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed

@ -62,6 +62,7 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(ContObjectType array,
const InputDomainType &,
vtkm::Id,
vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed

@ -62,6 +62,7 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(ContObjectType array,
const InputDomainType &,
vtkm::Id,
vtkm::Id) const
{
// Note: we ignore the size of the domain because the randomly accessed

@ -70,7 +70,7 @@ struct TryArrayInType
transport;
TestKernel<PortalType> kernel;
kernel.Portal = transport(handle, handle, ARRAY_SIZE);
kernel.Portal = transport(handle, handle, ARRAY_SIZE, ARRAY_SIZE);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(kernel, ARRAY_SIZE);
}

@ -68,7 +68,7 @@ struct TryArrayInOutType
transport;
TestKernel<PortalType> kernel;
kernel.Portal = transport(handle, handle, ARRAY_SIZE);
kernel.Portal = transport(handle, handle, ARRAY_SIZE, ARRAY_SIZE);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(kernel, ARRAY_SIZE);

@ -64,6 +64,7 @@ struct TryArrayOutType
TestKernel<PortalType> kernel;
kernel.Portal = transport(handle,
vtkm::cont::ArrayHandleIndex(ARRAY_SIZE),
ARRAY_SIZE,
ARRAY_SIZE);
VTKM_TEST_ASSERT(handle.GetNumberOfValues() == ARRAY_SIZE,

@ -84,7 +84,7 @@ void TransportWholeCellSetIn(Device)
transport;
TestKernel<ExecObjectType> kernel;
kernel.CellSet = transport(contObject, nullptr, 1);
kernel.CellSet = transport(contObject, nullptr, 1, 1);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(kernel, 1);
}

@ -61,7 +61,7 @@ void TryExecObjectTransport(Device)
transport;
TestKernel kernel;
kernel.Object = transport(contObject, nullptr, 1);
kernel.Object = transport(contObject, nullptr, 1, 1);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(kernel, 1);
}

@ -131,7 +131,7 @@ struct TryWholeArrayType
std::cout << "Check Transport WholeArrayOut" << std::endl;
TestOutKernel<typename OutTransportType::ExecObjectType> outKernel;
outKernel.Portal = OutTransportType()(array, nullptr, -1);
outKernel.Portal = OutTransportType()(array, nullptr, -1, -1);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(outKernel, ARRAY_SIZE);
@ -139,13 +139,13 @@ struct TryWholeArrayType
std::cout << "Check Transport WholeArrayIn" << std::endl;
TestInKernel<typename InTransportType::ExecObjectType> inKernel;
inKernel.Portal = InTransportType()(array, nullptr, -1);
inKernel.Portal = InTransportType()(array, nullptr, -1, -1);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(inKernel, ARRAY_SIZE);
std::cout << "Check Transport WholeArrayInOut" << std::endl;
TestInOutKernel<typename InOutTransportType::ExecObjectType> inOutKernel;
inOutKernel.Portal = InOutTransportType()(array, nullptr, -1);
inOutKernel.Portal = InOutTransportType()(array, nullptr, -1, -1);
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(inOutKernel, ARRAY_SIZE);
@ -180,7 +180,7 @@ struct TryAtomicArrayType
std::cout << "Check Transport AtomicArray" << std::endl;
TestAtomicKernel<typename TransportType::ExecObjectType>
kernel(TransportType()(array, nullptr, -1));
kernel(TransportType()(array, nullptr, -1, -1));
vtkm::cont::DeviceAdapterAlgorithm<Device>::Schedule(kernel, ARRAY_SIZE);

@ -274,6 +274,7 @@ private:
ExecObjectParameters execObjectParameters =
parameters.StaticTransformCont(TransportFunctorType(
invocation.GetInputDomain(),
inputRange,
outputRange));
// Get the arrays used for scattering input to output.

@ -152,6 +152,12 @@ public:
(this->Counts == other.Counts));
}
VTKM_CONT
bool operator!=(const vtkm::worklet::Keys<KeyType> &other) const
{
return !(*this == other);
}
private:
KeyArrayHandleType UniqueKeys;
vtkm::cont::ArrayHandle<vtkm::Id> SortedValuesMap;
@ -221,9 +227,14 @@ struct Transport<vtkm::cont::arg::TransportTagKeysIn,
VTKM_CONT
ExecObjectType operator()(const ContObjectType &object,
const ContObjectType &inputDomain,
vtkm::Id,
vtkm::Id) const
{
VTKM_ASSERT(object == inputDomain);
if (object != inputDomain)
{
throw vtkm::cont::ErrorBadValue(
"A Keys object must be the input domain.");
}
return object.PrepareForInput(Device());
}
@ -234,6 +245,7 @@ struct Transport<vtkm::cont::arg::TransportTagKeysIn,
VTKM_CONT
ExecObjectType operator()(const ContObjectType &,
const InputDomainType &,
vtkm::Id,
vtkm::Id) const = delete;
};
@ -258,9 +270,13 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(const ContObjectType &object,
const vtkm::worklet::Keys<KeyType> &keys,
vtkm::Id,
vtkm::Id) const
{
VTKM_ASSERT(object.GetNumberOfValues() == keys.GetNumberOfValues());
if (object.GetNumberOfValues() != keys.GetNumberOfValues())
{
throw vtkm::cont::ErrorBadValue("Input values array is wrong size.");
}
PermutedArrayType permutedArray(keys.GetSortedValuesMap(), object);
GroupedArrayType groupedArray(permutedArray, keys.GetOffsets());
@ -294,9 +310,14 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(ContObjectType object,
const vtkm::worklet::Keys<KeyType> &keys,
vtkm::Id,
vtkm::Id) const
{
VTKM_ASSERT(object.GetNumberOfValues() == keys.GetNumberOfValues());
if (object.GetNumberOfValues() != keys.GetNumberOfValues())
{
throw vtkm::cont::ErrorBadValue(
"Input/output values array is wrong size.");
}
PermutedArrayType permutedArray(keys.GetSortedValuesMap(), object);
GroupedArrayType groupedArray(permutedArray, keys.GetOffsets());
@ -330,6 +351,7 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(ContObjectType object,
const vtkm::worklet::Keys<KeyType> &keys,
vtkm::Id,
vtkm::Id) const
{
// The PrepareForOutput for ArrayHandleGroupVecVariable and
@ -361,9 +383,13 @@ struct Transport<
VTKM_CONT
ExecObjectType operator()(const ContObjectType &object,
const vtkm::worklet::Keys<KeyType> &inputDomain,
vtkm::Id inputRange,
vtkm::Id) const
{
if (object.GetNumberOfValues() != inputDomain.GetInputRange())
VTKM_ASSERT(inputDomain.GetInputRange() == inputRange);
(void)inputDomain; // Shut up compiler
if (object.GetNumberOfValues() != inputRange)
{
throw vtkm::cont::ErrorBadValue(
"Input array to worklet invocation the wrong size.");

@ -215,30 +215,41 @@ struct DispatcherBaseTransportInvokeTypes
typedef typename ControlSignatureTag::TransportTag TransportTag;
};
VTKM_CONT
inline
vtkm::Id FlatRange(vtkm::Id range)
{
return range;
}
VTKM_CONT
inline
vtkm::Id FlatRange(const vtkm::Id3 &range)
{
return range[0]*range[1]*range[2];
}
// A functor used in a StaticCast of a FunctionInterface to transport arguments
// from the control environment to the execution environment.
template<typename ControlInterface, typename InputDomainType, typename Device>
struct DispatcherBaseTransportFunctor
{
const InputDomainType &InputDomain; // Warning: this is a reference
vtkm::Id OutputSize;
VTKM_CONT
DispatcherBaseTransportFunctor(const InputDomainType &inputDomain,
vtkm::Id outputSize)
: InputDomain(inputDomain),
OutputSize(outputSize)
{ }
vtkm::Id InputRange;
vtkm::Id OutputRange;
// TODO: We need to think harder about how scheduling on 3D arrays works.
// Chances are we need to allow the transport for each argument to manage
// 3D indices (for example, allocate a 3D array instead of a 1D array).
// But for now, just treat all transports as 1D arrays.
template<typename InputRange, typename OutputRange>
VTKM_CONT
DispatcherBaseTransportFunctor(const InputDomainType &inputDomain,
vtkm::Id3 dimensions)
const InputRange &inputRange,
const OutputRange &outputRange)
: InputDomain(inputDomain),
OutputSize(dimensions[0]*dimensions[1]*dimensions[2])
InputRange(FlatRange(inputRange)),
OutputRange(FlatRange(outputRange))
{ }
@ -257,7 +268,10 @@ struct DispatcherBaseTransportFunctor
{
using TransportTag = typename DispatcherBaseTransportInvokeTypes<ControlInterface, Index>::TransportTag;
vtkm::cont::arg::Transport<TransportTag,ControlParameter,Device> transport;
return transport(invokeData, this->InputDomain, this->OutputSize);
return transport(invokeData,
this->InputDomain,
this->InputRange,
this->OutputRange);
}
private:
@ -491,6 +505,7 @@ private:
ExecObjectParameters execObjectParameters =
parameters.StaticTransformCont(TransportFunctorType(
invocation.GetInputDomain(),
inputRange,
outputRange));
// Get the arrays used for scattering input to output.

@ -78,9 +78,14 @@ struct Transport<TestTransportTag, vtkm::Id *, Device>
typedef TestExecObject ExecObjectType;
VTKM_CONT
ExecObjectType operator()(vtkm::Id *contData, vtkm::Id *, vtkm::Id size) const
ExecObjectType operator()(vtkm::Id *contData,
vtkm::Id *,
vtkm::Id inputRange,
vtkm::Id outputRange) const
{
VTKM_TEST_ASSERT(size == ARRAY_SIZE,
VTKM_TEST_ASSERT(inputRange == ARRAY_SIZE,
"Got unexpected size in test transport.");
VTKM_TEST_ASSERT(outputRange == ARRAY_SIZE,
"Got unexpected size in test transport.");
return ExecObjectType(contData);
}