vtk-m2/vtkm/cont/testing/UnitTestArrayHandleThreadSafety.cxx
Kenneth Moreland ef3f544a67 Add ability to attach token to general ArrayHandle
Duplicated the new versions of PrepareFor* methods from the basic
ArrayHandle that take a token in addition to the other arguments. The
ArrayHandle attaches itself to the token and will not allow operations
that make the returned portal invalid until the token goes out of scope.

Later the old versions will be deprecated.
2020-02-25 07:41:37 -07:00

211 lines
6.4 KiB
C++

//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleIndex.h>
#include <vtkm/cont/ArrayHandlePermutation.h>
#include <vtkm/cont/Token.h>
#include <vtkm/cont/serial/DeviceAdapterSerial.h>
#include <vtkm/cont/testing/Testing.h>
#include <array>
#include <future>
namespace
{
// Number of entries in every array exercised by this test.
constexpr vtkm::Id ARRAY_SIZE{ 10 };
// Number of concurrent tasks launched in each phase of the test.
constexpr std::size_t NUM_THREADS{ 20 };
// Element type stored in the arrays under test.
using ValueType = vtkm::FloatDefault;
/// Task body run concurrently (one copy per thread) by ThreadsIncrementToArray.
///
/// Obtains an in-place portal on the serial device; the local token keeps the
/// portal valid until this function returns. Verifies the array has
/// ARRAY_SIZE entries and that each entry i lies in
/// [TestValue(i), TestValue(i) + NUM_THREADS). Then adds 1 to each entry.
///
/// Returns true on success. Returns false after printing a diagnostic if any
/// check fails.
template <typename Storage>
bool IncrementArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
{
  vtkm::cont::Token token;
  auto portal = array.PrepareForInPlace(vtkm::cont::DeviceAdapterTagSerial{}, token);
  if (portal.GetNumberOfValues() != ARRAY_SIZE)
  {
    // End the line so diagnostics from concurrently running tasks do not merge.
    std::cout << "!!!!! Wrong array size: " << portal.GetNumberOfValues() << std::endl;
    return false;
  }
  for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
  {
    const ValueType value = portal.Get(index);
    const ValueType base = TestValue(index, ValueType{});
    // Up to NUM_THREADS - 1 other tasks may already have added 1 to this
    // entry, so anything in [base, base + NUM_THREADS) is legal here.
    if ((value < base) || (value >= base + static_cast<ValueType>(NUM_THREADS)))
    {
      std::cout << "!!!!! Unexpected value in array: " << value << std::endl;
      return false;
    }
    portal.Set(index, value + 1);
  }
  return true;
}
/// Task body run concurrently (one copy per thread) by ThreadsCheckArray.
///
/// Obtains a read-only portal on the serial device; the local token keeps the
/// portal valid until this function returns. Verifies the array has
/// ARRAY_SIZE entries and that each entry i equals TestValue(i) + NUM_THREADS,
/// the value expected after the increment phase.
///
/// Returns true on success. Returns false after printing a diagnostic if any
/// check fails.
template <typename Storage>
bool CheckArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
{
  vtkm::cont::Token token;
  auto portal = array.PrepareForInput(vtkm::cont::DeviceAdapterTagSerial{}, token);
  if (portal.GetNumberOfValues() != ARRAY_SIZE)
  {
    // End the line so diagnostics from concurrently running tasks do not merge.
    std::cout << "!!!!! Wrong array size: " << portal.GetNumberOfValues() << std::endl;
    return false;
  }
  for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
  {
    const ValueType value = portal.Get(index);
    // Use an explicit ValueType{} tag, matching IncrementArray.
    const ValueType expectedValue =
      TestValue(index, ValueType{}) + static_cast<ValueType>(NUM_THREADS);
    if (!test_equal(value, expectedValue))
    {
      std::cout << "!!!!! Unexpected value in array: " << value << std::endl;
      return false;
    }
  }
  return true;
}
/// Task body run concurrently (one copy per thread) by ThreadsDecrementArray.
///
/// Obtains an in-place portal on the serial device; the local token keeps the
/// portal valid until this function returns. Verifies the array has
/// ARRAY_SIZE entries and that each entry i satisfies
/// TestValue(i) < value < TestValue(i) + NUM_THREADS + 1. Then subtracts 1
/// from each entry.
///
/// Returns true on success. Returns false after printing a diagnostic if any
/// check fails.
template <typename Storage>
bool DecrementArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
{
  vtkm::cont::Token token;
  auto portal = array.PrepareForInPlace(vtkm::cont::DeviceAdapterTagSerial{}, token);
  if (portal.GetNumberOfValues() != ARRAY_SIZE)
  {
    // End the line so diagnostics from concurrently running tasks do not merge.
    std::cout << "!!!!! Wrong array size: " << portal.GetNumberOfValues() << std::endl;
    return false;
  }
  for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
  {
    const ValueType value = portal.Get(index);
    // Use an explicit ValueType{} tag, matching IncrementArray.
    const ValueType base = TestValue(index, ValueType{});
    // The entry starts at base + NUM_THREADS, and each task removes exactly 1,
    // so this task must see a value strictly above base.
    if ((value <= base) || (value >= base + static_cast<ValueType>(NUM_THREADS) + 1))
    {
      std::cout << "!!!!! Unexpected value in array: " << value << std::endl;
      return false;
    }
    portal.Set(index, value - 1);
  }
  return true;
}
/// First phase of the thread-safety test.
///
/// Allocates `array` for output on the serial device while holding a token,
/// then launches NUM_THREADS IncrementArray tasks against it. Next it fills
/// the array with TestValue(i), releases the token, and waits for every task,
/// asserting that each reported success.
///
/// The tasks are started before the fill on purpose: their PrepareForInPlace
/// calls are expected to wait until this function detaches its token.
template <typename Storage>
void ThreadsIncrementToArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
{
  // Declared BEFORE `token` on purpose. Locals are destroyed in reverse order,
  // so if anything below throws (for example std::async failing to start a
  // thread), `token` is destroyed first and releases the tasks waiting on it.
  // Only then do the std::async future destructors block for those tasks.
  // With the opposite declaration order, those destructors would wait on
  // tasks that are themselves waiting on the still-held token, which
  // deadlocks instead of failing the test.
  std::array<std::future<bool>, NUM_THREADS> futures;

  vtkm::cont::Token token;
  auto portal = array.PrepareForOutput(ARRAY_SIZE, vtkm::cont::DeviceAdapterTagSerial{}, token);

  std::cout << " Starting write threads" << std::endl;
  for (std::size_t index = 0; index < NUM_THREADS; ++index)
  {
    futures[index] = std::async(std::launch::async, IncrementArray<Storage>, array);
  }

  std::cout << " Filling array" << std::endl;
  for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
  {
    portal.Set(index, TestValue(index, ValueType{}));
  }

  std::cout << " Releasing portal" << std::endl;
  token.DetachFromAll();

  std::cout << " Wait for threads to complete" << std::endl;
  for (std::size_t index = 0; index < NUM_THREADS; ++index)
  {
    bool futureResult = futures[index].get();
    VTKM_TEST_ASSERT(futureResult, "Failure in IncrementArray");
  }
}
/// Second phase of the thread-safety test.
///
/// Takes a control-environment read portal and verifies that every entry i
/// equals TestValue(i) + NUM_THREADS, the state left by the increment phase.
/// Then, while still holding that portal, launches NUM_THREADS CheckArray
/// tasks that read the array on the serial device, and asserts that each
/// reported success.
template <typename Storage>
void ThreadsCheckArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
{
  std::cout << " Check array in control environment" << std::endl;
  auto portal = array.GetPortalConstControl();
  VTKM_TEST_ASSERT(portal.GetNumberOfValues() == ARRAY_SIZE,
                   "Wrong array size in control environment");
  // Previously only the size was checked here. Also verify the contents seen
  // through the control portal.
  for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
  {
    const ValueType expectedValue =
      TestValue(index, ValueType{}) + static_cast<ValueType>(NUM_THREADS);
    VTKM_TEST_ASSERT(test_equal(portal.Get(index), expectedValue),
                     "Unexpected value in control environment");
  }

  std::cout << " Starting threads to check" << std::endl;
  std::array<std::future<bool>, NUM_THREADS> futures;
  for (std::size_t index = 0; index < NUM_THREADS; ++index)
  {
    futures[index] = std::async(std::launch::async, CheckArray<Storage>, array);
  }

  std::cout << " Wait for threads to complete" << std::endl;
  for (std::size_t index = 0; index < NUM_THREADS; ++index)
  {
    bool futureResult = futures[index].get();
    VTKM_TEST_ASSERT(futureResult, "Failure in CheckArray");
  }
}
/// Third phase of the thread-safety test.
///
/// Launches NUM_THREADS DecrementArray tasks against `array`, waits for all of
/// them, and asserts that each reported success. Finally it checks the
/// resulting contents through a control portal with CheckPortal. Since each
/// task subtracts 1, the array should be back to its pre-increment values.
template <typename Storage>
void ThreadsDecrementArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
{
  std::cout << " Starting threads to decrement" << std::endl;
  std::array<std::future<bool>, NUM_THREADS> pending;
  for (auto& task : pending)
  {
    task = std::async(std::launch::async, DecrementArray<Storage>, array);
  }

  std::cout << " Wait for threads to complete" << std::endl;
  for (auto& task : pending)
  {
    const bool succeeded = task.get();
    VTKM_TEST_ASSERT(succeeded, "Failure in DecrementArray");
  }

  CheckPortal(array.GetPortalConstControl());
}
/// Runs the three phases of the thread-safety test on `array`, in order:
/// concurrent increments while the array is being filled, concurrent
/// read-only checks, then concurrent decrements followed by a final check.
template <typename Storage>
void DoThreadSafetyTest(vtkm::cont::ArrayHandle<ValueType, Storage> array)
{
  // Phase 1: fill the array while incrementing tasks are already launched.
  ThreadsIncrementToArray(array);
  // Phase 2: concurrent readers.
  ThreadsCheckArray(array);
  // Phase 3: concurrent decrements back to the starting values.
  ThreadsDecrementArray(array);
}
void DoTest()
{
std::cout << "Basic array handle." << std::endl;
vtkm::cont::ArrayHandle<ValueType> basicArray;
DoThreadSafetyTest(basicArray);
std::cout << "Fancy array handle." << std::endl;
vtkm::cont::ArrayHandle<ValueType> valueArray;
valueArray.Allocate(ARRAY_SIZE);
auto fancyArray =
vtkm::cont::make_ArrayHandlePermutation(vtkm::cont::ArrayHandleIndex(ARRAY_SIZE), valueArray);
DoThreadSafetyTest(fancyArray);
}
} // anonymous namespace
/// Test-driver entry point: hands DoTest and the command-line arguments to
/// the VTK-m testing harness and returns whatever status it reports.
int UnitTestArrayHandleThreadSafety(int argc, char* argv[])
{
  const auto status = vtkm::cont::testing::Testing::Run(DoTest, argc, argv);
  return status;
}