Merge topic 'remove-device-from-worklet-run-function'
5426ae0ed added runtime device tracker force cb0323880 removed device from run function for AverageByKey Acked-by: Kitware Robot <kwrobot@kitware.com> Acked-by: Kenneth Moreland <kmorel@sandia.gov> Merge-request: !1440
This commit is contained in:
commit
41b006064d
@ -69,17 +69,13 @@ struct AverageByKey
|
||||
template <typename KeyType,
|
||||
typename ValueType,
|
||||
typename InValuesStorage,
|
||||
typename OutAveragesStorage,
|
||||
typename Device>
|
||||
typename OutAveragesStorage>
|
||||
VTKM_CONT static void Run(const vtkm::worklet::Keys<KeyType>& keys,
|
||||
const vtkm::cont::ArrayHandle<ValueType, InValuesStorage>& inValues,
|
||||
vtkm::cont::ArrayHandle<ValueType, OutAveragesStorage>& outAverages,
|
||||
Device)
|
||||
vtkm::cont::ArrayHandle<ValueType, OutAveragesStorage>& outAverages)
|
||||
{
|
||||
VTKM_IS_DEVICE_ADAPTER_TAG(Device);
|
||||
|
||||
vtkm::worklet::DispatcherReduceByKey<AverageWorklet> dispatcher;
|
||||
dispatcher.SetDevice(Device());
|
||||
dispatcher.Invoke(keys, inValues, outAverages);
|
||||
}
|
||||
|
||||
@ -88,16 +84,14 @@ struct AverageByKey
|
||||
/// This method uses an existing \c Keys object to collected values by those keys and find
|
||||
/// the average of those groups.
|
||||
///
|
||||
template <typename KeyType, typename ValueType, typename InValuesStorage, typename Device>
|
||||
template <typename KeyType, typename ValueType, typename InValuesStorage>
|
||||
VTKM_CONT static vtkm::cont::ArrayHandle<ValueType> Run(
|
||||
const vtkm::worklet::Keys<KeyType>& keys,
|
||||
const vtkm::cont::ArrayHandle<ValueType, InValuesStorage>& inValues,
|
||||
Device)
|
||||
const vtkm::cont::ArrayHandle<ValueType, InValuesStorage>& inValues)
|
||||
{
|
||||
VTKM_IS_DEVICE_ADAPTER_TAG(Device);
|
||||
|
||||
vtkm::cont::ArrayHandle<ValueType> outAverages;
|
||||
Run(keys, inValues, outAverages, Device());
|
||||
Run(keys, inValues, outAverages);
|
||||
return outAverages;
|
||||
}
|
||||
|
||||
@ -136,15 +130,13 @@ struct AverageByKey
|
||||
class KeyInStorage,
|
||||
class KeyOutStorage,
|
||||
class ValueInStorage,
|
||||
class ValueOutStorage,
|
||||
class DeviceAdapter>
|
||||
class ValueOutStorage>
|
||||
VTKM_CONT static void Run(const vtkm::cont::ArrayHandle<KeyType, KeyInStorage>& keyArray,
|
||||
const vtkm::cont::ArrayHandle<ValueType, ValueInStorage>& valueArray,
|
||||
vtkm::cont::ArrayHandle<KeyType, KeyOutStorage>& outputKeyArray,
|
||||
vtkm::cont::ArrayHandle<ValueType, ValueOutStorage>& outputValueArray,
|
||||
DeviceAdapter)
|
||||
vtkm::cont::ArrayHandle<ValueType, ValueOutStorage>& outputValueArray)
|
||||
{
|
||||
using Algorithm = vtkm::cont::DeviceAdapterAlgorithm<DeviceAdapter>;
|
||||
using Algorithm = vtkm::cont::Algorithm;
|
||||
using ValueInArray = vtkm::cont::ArrayHandle<ValueType, ValueInStorage>;
|
||||
using IdArray = vtkm::cont::ArrayHandle<vtkm::Id>;
|
||||
using ValueArray = vtkm::cont::ArrayHandle<ValueType>;
|
||||
@ -177,7 +169,6 @@ struct AverageByKey
|
||||
|
||||
// get average
|
||||
DispatcherMapField<DivideWorklet> dispatcher;
|
||||
dispatcher.SetDevice(DeviceAdapter());
|
||||
dispatcher.Invoke(sumArray, countArray, outputValueArray);
|
||||
}
|
||||
};
|
||||
|
@ -82,17 +82,15 @@ void TryKeyType(KeyType)
|
||||
|
||||
// Create values array
|
||||
vtkm::cont::ArrayHandleCounting<vtkm::FloatDefault> valuesArray(0.0f, 1.0f, ARRAY_SIZE);
|
||||
vtkm::cont::GetGlobalRuntimeDeviceTracker().ForceDevice(VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
|
||||
|
||||
std::cout << " Try average with Keys object" << std::endl;
|
||||
CheckAverageByKey(
|
||||
keys.GetUniqueKeys(),
|
||||
vtkm::worklet::AverageByKey::Run(keys, valuesArray, VTKM_DEFAULT_DEVICE_ADAPTER_TAG()));
|
||||
CheckAverageByKey(keys.GetUniqueKeys(), vtkm::worklet::AverageByKey::Run(keys, valuesArray));
|
||||
|
||||
std::cout << " Try average with device adapter's reduce by keys" << std::endl;
|
||||
vtkm::cont::ArrayHandle<KeyType> outputKeys;
|
||||
vtkm::cont::ArrayHandle<vtkm::FloatDefault> outputValues;
|
||||
vtkm::worklet::AverageByKey::Run(
|
||||
keysArray, valuesArray, outputKeys, outputValues, VTKM_DEFAULT_DEVICE_ADAPTER_TAG());
|
||||
vtkm::worklet::AverageByKey::Run(keysArray, valuesArray, outputKeys, outputValues);
|
||||
CheckAverageByKey(outputKeys, outputValues);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user