mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-20 19:15:45 +00:00
ec34cb56c4
Also fix deadlocks that occur when portals are not destroyed in time.
211 lines
6.4 KiB
C++
211 lines
6.4 KiB
C++
//============================================================================
|
|
// Copyright (c) Kitware, Inc.
|
|
// All rights reserved.
|
|
// See LICENSE.txt for details.
|
|
//
|
|
// This software is distributed WITHOUT ANY WARRANTY; without even
|
|
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
|
// PURPOSE. See the above copyright notice for more information.
|
|
//============================================================================
|
|
|
|
#include <vtkm/cont/ArrayHandle.h>
|
|
#include <vtkm/cont/ArrayHandleIndex.h>
|
|
#include <vtkm/cont/ArrayHandlePermutation.h>
|
|
#include <vtkm/cont/Token.h>
|
|
|
|
#include <vtkm/cont/serial/DeviceAdapterSerial.h>
|
|
|
|
#include <vtkm/cont/testing/Testing.h>
|
|
|
|
#include <array>
|
|
#include <future>
|
|
|
|
namespace
|
|
{
|
|
|
|
constexpr vtkm::Id ARRAY_SIZE = 10;
|
|
constexpr std::size_t NUM_THREADS = 20;
|
|
|
|
using ValueType = vtkm::FloatDefault;
|
|
|
|
template <typename Storage>
|
|
bool IncrementArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
|
|
{
|
|
vtkm::cont::Token token;
|
|
auto portal = array.PrepareForInPlace(vtkm::cont::DeviceAdapterTagSerial{}, token);
|
|
if (portal.GetNumberOfValues() != ARRAY_SIZE)
|
|
{
|
|
std::cout << "!!!!! Wrong array size: " << portal.GetNumberOfValues();
|
|
return false;
|
|
}
|
|
|
|
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
|
|
{
|
|
ValueType value = portal.Get(index);
|
|
ValueType base = TestValue(index, ValueType{});
|
|
if ((value < base) || (value >= base + static_cast<ValueType>(NUM_THREADS)))
|
|
{
|
|
std::cout << "!!!!! Unexpected value in array: " << value << std::endl;
|
|
return false;
|
|
}
|
|
portal.Set(index, value + 1);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
template <typename Storage>
|
|
bool CheckArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
|
|
{
|
|
vtkm::cont::Token token;
|
|
auto portal = array.PrepareForInput(vtkm::cont::DeviceAdapterTagSerial{}, token);
|
|
if (portal.GetNumberOfValues() != ARRAY_SIZE)
|
|
{
|
|
std::cout << "!!!!! Wrong array size: " << portal.GetNumberOfValues();
|
|
return false;
|
|
}
|
|
|
|
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
|
|
{
|
|
ValueType value = portal.Get(index);
|
|
ValueType expectedValue = TestValue(index, value) + static_cast<ValueType>(NUM_THREADS);
|
|
if (!test_equal(value, expectedValue))
|
|
{
|
|
std::cout << "!!!!! Unexpected value in array: " << value << std::endl;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
template <typename Storage>
|
|
bool DecrementArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
|
|
{
|
|
vtkm::cont::Token token;
|
|
auto portal = array.PrepareForInPlace(vtkm::cont::DeviceAdapterTagSerial{}, token);
|
|
if (portal.GetNumberOfValues() != ARRAY_SIZE)
|
|
{
|
|
std::cout << "!!!!! Wrong array size: " << portal.GetNumberOfValues();
|
|
return false;
|
|
}
|
|
|
|
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
|
|
{
|
|
ValueType value = portal.Get(index);
|
|
ValueType base = TestValue(index, value);
|
|
if ((value <= base) || (value >= base + static_cast<ValueType>(NUM_THREADS) + 1))
|
|
{
|
|
std::cout << "!!!!! Unexpected value in array: " << value << std::endl;
|
|
return false;
|
|
}
|
|
portal.Set(index, value - 1);
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
template <typename Storage>
|
|
void ThreadsIncrementToArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
|
|
{
|
|
vtkm::cont::Token token;
|
|
auto portal = array.PrepareForOutput(ARRAY_SIZE, vtkm::cont::DeviceAdapterTagSerial{}, token);
|
|
|
|
std::cout << " Starting write threads" << std::endl;
|
|
std::array<decltype(std::async(std::launch::async, IncrementArray<Storage>, array)), NUM_THREADS>
|
|
futures;
|
|
for (std::size_t index = 0; index < NUM_THREADS; ++index)
|
|
{
|
|
futures[index] = std::async(std::launch::async, IncrementArray<Storage>, array);
|
|
}
|
|
|
|
std::cout << " Filling array" << std::endl;
|
|
for (vtkm::Id index = 0; index < ARRAY_SIZE; ++index)
|
|
{
|
|
portal.Set(index, TestValue(index, ValueType{}));
|
|
}
|
|
|
|
std::cout << " Releasing portal" << std::endl;
|
|
token.DetachFromAll();
|
|
|
|
std::cout << " Wait for threads to complete" << std::endl;
|
|
for (std::size_t index = 0; index < NUM_THREADS; ++index)
|
|
{
|
|
bool futureResult = futures[index].get();
|
|
VTKM_TEST_ASSERT(futureResult, "Failure in IncrementArray");
|
|
}
|
|
}
|
|
|
|
template <typename Storage>
|
|
void ThreadsCheckArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
|
|
{
|
|
std::cout << " Check array in control environment" << std::endl;
|
|
auto portal = array.ReadPortal();
|
|
VTKM_TEST_ASSERT(portal.GetNumberOfValues() == ARRAY_SIZE);
|
|
|
|
std::cout << " Starting threads to check" << std::endl;
|
|
std::array<decltype(std::async(std::launch::async, CheckArray<Storage>, array)), NUM_THREADS>
|
|
futures;
|
|
for (std::size_t index = 0; index < NUM_THREADS; ++index)
|
|
{
|
|
futures[index] = std::async(std::launch::async, CheckArray<Storage>, array);
|
|
}
|
|
|
|
std::cout << " Wait for threads to complete" << std::endl;
|
|
for (std::size_t index = 0; index < NUM_THREADS; ++index)
|
|
{
|
|
bool futureResult = futures[index].get();
|
|
VTKM_TEST_ASSERT(futureResult, "Failure in CheckArray");
|
|
}
|
|
}
|
|
|
|
template <typename Storage>
|
|
void ThreadsDecrementArray(vtkm::cont::ArrayHandle<ValueType, Storage> array)
|
|
{
|
|
std::cout << " Starting threads to decrement" << std::endl;
|
|
std::array<decltype(std::async(std::launch::async, DecrementArray<Storage>, array)), NUM_THREADS>
|
|
futures;
|
|
for (std::size_t index = 0; index < NUM_THREADS; ++index)
|
|
{
|
|
futures[index] = std::async(std::launch::async, DecrementArray<Storage>, array);
|
|
}
|
|
|
|
std::cout << " Wait for threads to complete" << std::endl;
|
|
for (std::size_t index = 0; index < NUM_THREADS; ++index)
|
|
{
|
|
bool futureResult = futures[index].get();
|
|
VTKM_TEST_ASSERT(futureResult, "Failure in DecrementArray");
|
|
}
|
|
|
|
CheckPortal(array.ReadPortal());
|
|
}
|
|
|
|
template <typename Storage>
|
|
void DoThreadSafetyTest(vtkm::cont::ArrayHandle<ValueType, Storage> array)
|
|
{
|
|
ThreadsIncrementToArray(array);
|
|
ThreadsCheckArray(array);
|
|
ThreadsDecrementArray(array);
|
|
}
|
|
|
|
void DoTest()
|
|
{
|
|
std::cout << "Basic array handle." << std::endl;
|
|
vtkm::cont::ArrayHandle<ValueType> basicArray;
|
|
DoThreadSafetyTest(basicArray);
|
|
|
|
std::cout << "Fancy array handle." << std::endl;
|
|
vtkm::cont::ArrayHandle<ValueType> valueArray;
|
|
valueArray.Allocate(ARRAY_SIZE);
|
|
auto fancyArray =
|
|
vtkm::cont::make_ArrayHandlePermutation(vtkm::cont::ArrayHandleIndex(ARRAY_SIZE), valueArray);
|
|
DoThreadSafetyTest(fancyArray);
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
int UnitTestArrayHandleThreadSafety(int argc, char* argv[])
|
|
{
|
|
return vtkm::cont::testing::Testing::Run(DoTest, argc, argv);
|
|
}
|