Add ArrayHandleBitField, a boolean-valued AH backed by a BitField.

This commit is contained in:
Allison Vacanti 2019-03-07 16:45:21 -05:00 committed by Robert Maynard
parent 56cc5c3d3a
commit a66510e819
4 changed files with 352 additions and 1 deletions

@ -13,3 +13,39 @@ order.
The new AtomicInterface classes provide an abstraction into bitwise
atomic operations across control and execution environments and are
used to implement the BitPortals.
BitFields may be used as boolean-typed ArrayHandles using the
ArrayHandleBitField adapter. ArrayHandleBitField uses atomic operations to read
and write bits in the BitField, and is safe to use in concurrent code.
For example, a simple worklet that merges two arrays based on a boolean
condition is tested in TestingBitField:
```
class ConditionalMergeWorklet : public vtkm::worklet::WorkletMapField
{
public:
using ControlSignature = void(FieldIn cond,
FieldIn trueVals,
FieldIn falseVals,
FieldOut result);
using ExecutionSignature = _4(_1, _2, _3);
template <typename T>
VTKM_EXEC T operator()(bool cond, const T& trueVal, const T& falseVal) const
{
return cond ? trueVal : falseVal;
}
};
BitField bits = ...;
auto condArray = vtkm::cont::make_ArrayHandleBitField(bits);
auto trueArray = vtkm::cont::make_ArrayHandleCounting<vtkm::Id>(20, 2, NUM_BITS);
auto falseArray = vtkm::cont::make_ArrayHandleCounting<vtkm::Id>(13, 2, NUM_BITS);
vtkm::cont::ArrayHandle<vtkm::Id> output;
vtkm::worklet::DispatcherMapField<ConditionalMergeWorklet> dispatcher;
dispatcher.Invoke(condArray, trueArray, falseArray, output);
```

@ -0,0 +1,220 @@
//=============================================================================
//
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//
// Copyright 2019 National Technology & Engineering Solutions of Sandia, LLC (NTESS).
// Copyright 2019 UT-Battelle, LLC.
// Copyright 2019 Los Alamos National Security.
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National
// Laboratory (LANL), the U.S. Government retains certain rights in
// this software.
//
//=============================================================================
#ifndef vtk_m_cont_ArrayHandleBitField_h
#define vtk_m_cont_ArrayHandleBitField_h
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/BitField.h>
#include <vtkm/cont/Storage.h>
namespace vtkm
{
namespace cont
{
namespace internal
{
template <typename BitPortalType>
class ArrayPortalBitField
{
public:
using ValueType = bool;
VTKM_EXEC_CONT
explicit ArrayPortalBitField(const BitPortalType& portal) noexcept : BitPortal{ portal } {}
VTKM_EXEC_CONT
explicit ArrayPortalBitField(BitPortalType&& portal) noexcept : BitPortal{ std::move(portal) } {}
ArrayPortalBitField() noexcept = default;
ArrayPortalBitField(const ArrayPortalBitField&) noexcept = default;
ArrayPortalBitField(ArrayPortalBitField&&) noexcept = default;
ArrayPortalBitField& operator=(const ArrayPortalBitField&) noexcept = default;
ArrayPortalBitField& operator=(ArrayPortalBitField&&) noexcept = default;
VTKM_EXEC_CONT
vtkm::Id GetNumberOfValues() const noexcept { return this->BitPortal.GetNumberOfBits(); }
VTKM_EXEC_CONT
ValueType Get(vtkm::Id index) const noexcept { return this->BitPortal.GetBit(index); }
VTKM_EXEC_CONT
void Set(vtkm::Id index, ValueType value) const
{
// Use an atomic set so we don't clash with other threads writing nearby
// bits.
this->BitPortal.SetBitAtomic(index, value);
}
private:
BitPortalType BitPortal;
};
struct VTKM_ALWAYS_EXPORT StorageTagBitField
{
};
template <>
class Storage<bool, StorageTagBitField>
{
using BitPortalType = vtkm::cont::detail::BitPortal<vtkm::cont::internal::AtomicInterfaceControl>;
using BitPortalConstType =
vtkm::cont::detail::BitPortalConst<vtkm::cont::internal::AtomicInterfaceControl>;
public:
using ValueType = bool;
using PortalType = vtkm::cont::internal::ArrayPortalBitField<BitPortalType>;
using PortalConstType = vtkm::cont::internal::ArrayPortalBitField<BitPortalConstType>;
explicit VTKM_CONT Storage(const vtkm::cont::BitField& data)
: Data{ data }
{
}
explicit VTKM_CONT Storage(vtkm::cont::BitField&& data) noexcept : Data{ std::move(data) } {}
VTKM_CONT Storage() = default;
VTKM_CONT Storage(const Storage& src) = default;
VTKM_CONT Storage(Storage&& src) noexcept = default;
VTKM_CONT Storage& operator=(const Storage& src) = default;
VTKM_CONT Storage& operator=(Storage&& src) noexcept = default;
VTKM_CONT
PortalType GetPortal() { return PortalType{ this->Data.GetPortalControl() }; }
VTKM_CONT
PortalConstType GetPortalConst() { return PortalConstType{ this->Data.GetPortalConstControl() }; }
VTKM_CONT vtkm::Id GetNumberOfValues() const { return this->Data.GetNumberOfBits(); }
VTKM_CONT void Allocate(vtkm::Id numberOfValues) { this->Data.Allocate(numberOfValues); }
VTKM_CONT void Shrink(vtkm::Id numberOfValues) { this->Data.Shrink(numberOfValues); }
VTKM_CONT void ReleaseResources() { this->Data.ReleaseResources(); }
VTKM_CONT vtkm::cont::BitField GetBitField() const { return this->Data; }
private:
vtkm::cont::BitField Data;
};
template <typename Device>
class ArrayTransfer<bool, StorageTagBitField, Device>
{
using AtomicInterface = AtomicInterfaceExecution<Device>;
using StorageType = Storage<bool, StorageTagBitField>;
using BitPortalExecution = vtkm::cont::detail::BitPortal<AtomicInterface>;
using BitPortalConstExecution = vtkm::cont::detail::BitPortalConst<AtomicInterface>;
public:
using ValueType = bool;
using PortalControl = typename StorageType::PortalType;
using PortalConstControl = typename StorageType::PortalConstType;
using PortalExecution = vtkm::cont::internal::ArrayPortalBitField<BitPortalExecution>;
using PortalConstExecution = vtkm::cont::internal::ArrayPortalBitField<BitPortalConstExecution>;
VTKM_CONT
explicit ArrayTransfer(StorageType* storage)
: Data{ storage->GetBitField() }
{
}
VTKM_CONT
vtkm::Id GetNumberOfValues() const { return this->Data.GetNumberOfBits(); }
VTKM_CONT
PortalConstExecution PrepareForInput(bool vtkmNotUsed(updateData))
{
return PortalConstExecution{ this->Data.PrepareForInput(Device{}) };
}
VTKM_CONT
PortalExecution PrepareForInPlace(bool vtkmNotUsed(updateData))
{
return PortalExecution{ this->Data.PrepareForInPlace(Device{}) };
}
VTKM_CONT
PortalExecution PrepareForOutput(vtkm::Id numberOfValues)
{
return PortalExecution{ this->Data.PrepareForOutput(numberOfValues, Device{}) };
}
VTKM_CONT
void RetrieveOutputData(StorageType* vtkmNotUsed(storage)) const
{
// Implementation of this method should be unnecessary. The internal
// bitfield should automatically retrieve the output data as necessary.
}
VTKM_CONT
void Shrink(vtkm::Id numberOfValues) { this->Data.Shrink(numberOfValues); }
VTKM_CONT
void ReleaseResources() { this->Data.ReleaseResources(); }
private:
vtkm::cont::BitField Data;
};
} // end namespace internal
/// The ArrayHandleBitField class is a boolean-valued ArrayHandle that is backed
/// by a BitField.
///
class ArrayHandleBitField : public ArrayHandle<bool, internal::StorageTagBitField>
{
public:
VTKM_ARRAY_HANDLE_SUBCLASS_NT(ArrayHandleBitField,
(ArrayHandle<bool, internal::StorageTagBitField>));
VTKM_CONT
explicit ArrayHandleBitField(const vtkm::cont::BitField& bitField)
: Superclass{ StorageType{ bitField } }
{
}
VTKM_CONT
explicit ArrayHandleBitField(vtkm::cont::BitField&& bitField) noexcept
: Superclass{ StorageType{ std::move(bitField) } }
{
}
VTKM_CONT
vtkm::cont::BitField GetBitField() const { return this->GetStorage().GetBitField(); }
};
VTKM_CONT inline vtkm::cont::ArrayHandleBitField make_ArrayHandleBitField(
const vtkm::cont::BitField& bitField)
{
return ArrayHandleBitField{ bitField };
}
VTKM_CONT inline vtkm::cont::ArrayHandleBitField make_ArrayHandleBitField(
vtkm::cont::BitField&& bitField) noexcept
{
return ArrayHandleBitField{ std::move(bitField) };
}
}
} // end namespace vtkm::cont
#endif // vtk_m_cont_ArrayHandleBitField_h

@ -22,6 +22,7 @@ set(headers
Algorithm.h
ArrayCopy.h
ArrayHandle.h
ArrayHandleBitField.h
ArrayHandleCartesianProduct.h
ArrayHandleCast.h
ArrayHandleCompositeVector.h

@ -20,6 +20,8 @@
#ifndef vtk_m_cont_testing_TestingBitFields_h
#define vtk_m_cont_testing_TestingBitFields_h
#include <vtkm/cont/ArrayHandleBitField.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/BitField.h>
#include <vtkm/cont/DeviceAdapterAlgorithm.h>
#include <vtkm/cont/RuntimeDeviceTracker.h>
@ -28,6 +30,8 @@
#include <vtkm/exec/FunctorBase.h>
#include <vtkm/worklet/WorkletMapField.h>
#include <cstdio>
#define DEVICE_ASSERT_MSG(cond, message) \
@ -67,6 +71,19 @@ namespace cont
namespace testing
{
class ConditionalMergeWorklet : public vtkm::worklet::WorkletMapField
{
public:
using ControlSignature = void(FieldIn cond, FieldIn trueVals, FieldIn falseVals, FieldOut result);
using ExecutionSignature = _4(_1, _2, _3);
template <typename T>
VTKM_EXEC T operator()(bool cond, const T& trueVal, const T& falseVal) const
{
return cond ? trueVal : falseVal;
}
};
/// This class has a single static member, Run, that runs all tests with the
/// given DeviceAdapter.
template <class DeviceAdapterTag>
@ -514,6 +531,81 @@ struct TestingBitField
testMask64(129, 0x0000000000000001);
}
struct ArrayHandleBitFieldChecker : vtkm::exec::FunctorBase
{
using PortalType = typename ArrayHandleBitField::ExecutionTypes<DeviceAdapterTag>::Portal;
PortalType Portal;
bool InvertReference;
VTKM_EXEC_CONT
ArrayHandleBitFieldChecker(PortalType portal, bool invert)
: Portal(portal)
, InvertReference(invert)
{
}
VTKM_EXEC
void operator()(vtkm::Id i) const
{
const bool ref = this->InvertReference ? !RandomBitFromIndex(i) : RandomBitFromIndex(i);
if (this->Portal.Get(i) != ref)
{
this->RaiseError("Unexpected value from ArrayHandleBitField portal.");
return;
}
// Flip the bit for the next kernel launch, which tests that the bitfield
// is inverted.
this->Portal.Set(i, !ref);
}
};
VTKM_CONT
static void TestArrayHandleBitField()
{
auto handle = vtkm::cont::make_ArrayHandleBitField(RandomBitField());
const vtkm::Id numBits = handle.GetNumberOfValues();
VTKM_TEST_ASSERT(numBits == NUM_BITS,
"ArrayHandleBitField returned the wrong number of values. "
"Expected: ",
NUM_BITS,
" got: ",
numBits);
Algo::Schedule(
ArrayHandleBitFieldChecker{ handle.PrepareForInPlace(DeviceAdapterTag{}), false }, numBits);
Algo::Schedule(ArrayHandleBitFieldChecker{ handle.PrepareForInPlace(DeviceAdapterTag{}), true },
numBits);
}
VTKM_CONT
static void TestArrayInvokeWorklet()
{
auto condArray = vtkm::cont::make_ArrayHandleBitField(RandomBitField());
auto trueArray = vtkm::cont::make_ArrayHandleCounting<vtkm::Id>(20, 2, NUM_BITS);
auto falseArray = vtkm::cont::make_ArrayHandleCounting<vtkm::Id>(13, 2, NUM_BITS);
vtkm::cont::ArrayHandle<vtkm::Id> output;
vtkm::worklet::DispatcherMapField<ConditionalMergeWorklet> dispatcher;
dispatcher.Invoke(condArray, trueArray, falseArray, output);
auto condVals = condArray.GetPortalConstControl();
auto trueVals = trueArray.GetPortalConstControl();
auto falseVals = falseArray.GetPortalConstControl();
auto outVals = output.GetPortalConstControl();
VTKM_TEST_ASSERT(condVals.GetNumberOfValues() == trueVals.GetNumberOfValues());
VTKM_TEST_ASSERT(condVals.GetNumberOfValues() == falseVals.GetNumberOfValues());
VTKM_TEST_ASSERT(condVals.GetNumberOfValues() == outVals.GetNumberOfValues());
for (vtkm::Id i = 0; i < condVals.GetNumberOfValues(); ++i)
{
VTKM_TEST_ASSERT(outVals.Get(i) == (condVals.Get(i) ? trueVals.Get(i) : falseVals.Get(i)));
}
}
struct TestRunner
{
VTKM_CONT
@ -523,13 +615,15 @@ struct TestingBitField
TestingBitField::TestControlPortals();
TestingBitField::TestExecutionPortals();
TestingBitField::TestFinalWordMask();
TestingBitField::TestArrayHandleBitField();
TestingBitField::TestArrayInvokeWorklet();
}
};
public:
static VTKM_CONT int Run(int argc, char* argv[])
{
vtkm::cont::GetGlobalRuntimeDeviceTracker().ForceDevice(DeviceAdapterTag());
vtkm::cont::GetRuntimeDeviceTracker().ForceDevice(DeviceAdapterTag());
return vtkm::cont::testing::Testing::Run(TestRunner{}, argc, argv);
}
};