From 42c6959be3b122152ffac2058228bbbd4fdf89e0 Mon Sep 17 00:00:00 2001 From: Sujin Philip Date: Wed, 1 Feb 2023 10:58:01 -0500 Subject: [PATCH] Add Abort execution feature Initial changes to add support for aborting execution. --- docs/changelog/vtkm-abort.md | 22 ++++ vtkm/cont/CMakeLists.txt | 1 + vtkm/cont/ErrorUserAbort.h | 42 ++++++++ vtkm/cont/RuntimeDeviceTracker.cxx | 50 ++++++---- vtkm/cont/RuntimeDeviceTracker.h | 150 ++++++++++++++-------------- vtkm/cont/TryExecute.cxx | 8 ++ vtkm/cont/TryExecute.h | 6 ++ vtkm/cont/testing/CMakeLists.txt | 6 ++ vtkm/cont/testing/UnitTestAbort.cxx | 105 +++++++++++++++++++ vtkm/cont/vtkm.module | 1 + 10 files changed, 296 insertions(+), 95 deletions(-) create mode 100644 docs/changelog/vtkm-abort.md create mode 100644 vtkm/cont/ErrorUserAbort.h create mode 100644 vtkm/cont/testing/UnitTestAbort.cxx diff --git a/docs/changelog/vtkm-abort.md b/docs/changelog/vtkm-abort.md new file mode 100644 index 000000000..8bd25c7d0 --- /dev/null +++ b/docs/changelog/vtkm-abort.md @@ -0,0 +1,22 @@ +# Add initial support for aborting execution + +VTK-m now has preliminary support for aborting execution. The per-thread instances of +`RuntimeDeviceTracker` have a functor called `AbortChecker`. This functor can be set using +`RuntimeDeviceTracker::SetAbortChecker()` and cleared by `RuntimeDeviceTracker::ClearAbortChecker()` +The abort checker functor should return `true` if an abort is requested for the thread, +otherwise, it should return `false`. + +Before launching a new task, `TaskExecute` calls the functor to see if an abort is requested, +and If so, throws an exception of type `vtkm::cont::ErrorUserAbort`. + +Any code that wants to use the abort feature, should set an appropriate `AbortChecker` +functor for the target thread. Then any piece of code that has parts that can execute on +the device should be put under a `try-catch` block. Any clean-up that is required for an +aborted execution should be handled in a `catch` block that handles exceptions of type +`vtkm::cont::ErrorUserAbort`. + +The limitation of this implementation is that it is control-side only. The check for abort +is done before launching a new device task. Once execution has begun on the device, there is +currently no way to abort that. Therefore, this feature is only useful for aborting code +that is made up of several smaller device task launches (Which is the case for most +worklets and filters in VTK-m) diff --git a/vtkm/cont/CMakeLists.txt b/vtkm/cont/CMakeLists.txt index 35f356527..761652c22 100644 --- a/vtkm/cont/CMakeLists.txt +++ b/vtkm/cont/CMakeLists.txt @@ -94,6 +94,7 @@ set(headers ErrorExecution.h ErrorFilterExecution.h ErrorInternal.h + ErrorUserAbort.h ExecutionAndControlObjectBase.h ExecutionObjectBase.h Field.h diff --git a/vtkm/cont/ErrorUserAbort.h b/vtkm/cont/ErrorUserAbort.h new file mode 100644 index 000000000..04f6371b9 --- /dev/null +++ b/vtkm/cont/ErrorUserAbort.h @@ -0,0 +1,42 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +//============================================================================ +#ifndef vtk_m_cont_ErrorUserAbort_h +#define vtk_m_cont_ErrorUserAbort_h + +#include + +namespace vtkm +{ +namespace cont +{ + +VTKM_SILENCE_WEAK_VTABLE_WARNING_START + +/// This class is thrown when vtk-m detects a request for aborting execution +/// in the current thread +/// +class VTKM_ALWAYS_EXPORT ErrorUserAbort : public Error +{ +public: + ErrorUserAbort() + : Error(Message, true) + { + } + +private: + static constexpr const char* Message = "User abort detected."; +}; + +VTKM_SILENCE_WEAK_VTABLE_WARNING_END + +} +} // namespace vtkm::cont + +#endif // vtk_m_cont_ErrorUserAbort_h diff --git a/vtkm/cont/RuntimeDeviceTracker.cxx b/vtkm/cont/RuntimeDeviceTracker.cxx index ce6fcdf17..89d2e9eab 100644 --- a/vtkm/cont/RuntimeDeviceTracker.cxx +++ b/vtkm/cont/RuntimeDeviceTracker.cxx @@ -35,6 +35,7 @@ struct RuntimeDeviceTrackerInternals std::array RuntimeAllowed; bool ThreadFriendlyMemAlloc = false; + std::function AbortChecker; }; } @@ -186,6 +187,28 @@ VTKM_CONT void RuntimeDeviceTracker::CopyStateFrom(const vtkm::cont::RuntimeDevi *(this->Internals) = *tracker.Internals; } +VTKM_CONT +void RuntimeDeviceTracker::SetAbortChecker(const std::function& func) +{ + this->Internals->AbortChecker = func; +} + +VTKM_CONT +bool RuntimeDeviceTracker::CheckForAbortRequest() const +{ + if (this->Internals->AbortChecker) + { + return this->Internals->AbortChecker(); + } + return false; +} + +VTKM_CONT +void RuntimeDeviceTracker::ClearAbortChecker() +{ + this->Internals->AbortChecker = nullptr; +} + VTKM_CONT void RuntimeDeviceTracker::PrintSummary(std::ostream& out) const { @@ -228,25 +251,6 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker( VTKM_LOG_S(vtkm::cont::LogLevel::DevicesEnabled, "Entering scoped runtime region"); } -VTKM_CONT -ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device, - RuntimeDeviceTrackerMode mode) - : ScopedRuntimeDeviceTracker(GetRuntimeDeviceTracker()) -{ - if (mode == RuntimeDeviceTrackerMode::Force) - { - this->ForceDevice(device); - } - else if (mode == RuntimeDeviceTrackerMode::Enable) - { - this->ResetDevice(device); - } - else if (mode == RuntimeDeviceTrackerMode::Disable) - { - this->DisableDevice(device); - } -} - VTKM_CONT ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker( vtkm::cont::DeviceAdapterId device, @@ -268,6 +272,14 @@ ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker( } } +VTKM_CONT ScopedRuntimeDeviceTracker::ScopedRuntimeDeviceTracker( + const std::function& abortChecker, + const vtkm::cont::RuntimeDeviceTracker& tracker) + : ScopedRuntimeDeviceTracker(tracker) +{ + this->SetAbortChecker(abortChecker); +} + VTKM_CONT ScopedRuntimeDeviceTracker::~ScopedRuntimeDeviceTracker() { diff --git a/vtkm/cont/RuntimeDeviceTracker.h b/vtkm/cont/RuntimeDeviceTracker.h index 5de7c22bc..c847ec92a 100644 --- a/vtkm/cont/RuntimeDeviceTracker.h +++ b/vtkm/cont/RuntimeDeviceTracker.h @@ -17,6 +17,7 @@ #include #include +#include #include namespace vtkm @@ -123,6 +124,18 @@ public: /// VTKM_CONT void CopyStateFrom(const vtkm::cont::RuntimeDeviceTracker& tracker); + ///@{ + /// \brief Set/Clear the abort checker functor. + /// + /// If set the abort checker functor is called by \c TryExecute before scheduling + /// a task on a device from the associated the thread. If the functor returns + /// \e true, an exception is thrown. + VTKM_CONT void SetAbortChecker(const std::function& func); + VTKM_CONT void ClearAbortChecker(); + ///@} + + VTKM_CONT bool CheckForAbortRequest() const; + VTKM_CONT void PrintSummary(std::ostream& out) const; private: @@ -149,82 +162,7 @@ private: void LogEnabledDevices() const; }; - -enum struct RuntimeDeviceTrackerMode -{ - Force, - Enable, - Disable -}; - -/// A class that can be used to determine or modify which device adapter -/// VTK-m algorithms should be run on. This class captures the state -/// of the per-thread device adapter and will revert any changes applied -/// during its lifetime on destruction. -/// -/// -struct VTKM_CONT_EXPORT ScopedRuntimeDeviceTracker : public vtkm::cont::RuntimeDeviceTracker -{ - /// Construct a ScopedRuntimeDeviceTracker where the state of the active devices - /// for the current thread are determined by the parameters to the constructor. - /// - /// 'Force' - /// - Force-Enable the provided single device adapter - /// - Force-Enable all device adapters when using vtkm::cont::DeviceAdaterTagAny - /// 'Enable' - /// - Enable the provided single device adapter if it was previously disabled - /// - Enable all device adapters that are currently disabled when using - /// vtkm::cont::DeviceAdaterTagAny - /// 'Disable' - /// - Disable the provided single device adapter - /// - Disable all device adapters when using vtkm::cont::DeviceAdaterTagAny - /// - /// Constructor is not thread safe - VTKM_CONT ScopedRuntimeDeviceTracker( - vtkm::cont::DeviceAdapterId device, - RuntimeDeviceTrackerMode mode = RuntimeDeviceTrackerMode::Force); - - /// Construct a ScopedRuntimeDeviceTracker associated with the thread - /// associated with the provided tracker. The active devices - /// for the current thread are determined by the parameters to the constructor. - /// - /// 'Force' - /// - Force-Enable the provided single device adapter - /// - Force-Enable all device adapters when using vtkm::cont::DeviceAdaterTagAny - /// 'Enable' - /// - Enable the provided single device adapter if it was previously disabled - /// - Enable all device adapters that are currently disabled when using - /// vtkm::cont::DeviceAdaterTagAny - /// 'Disable' - /// - Disable the provided single device adapter - /// - Disable all device adapters when using vtkm::cont::DeviceAdaterTagAny - /// - /// Any modifications to the ScopedRuntimeDeviceTracker will effect what - /// ever thread the \c tracker is associated with, which might not be - /// the thread which ScopedRuntimeDeviceTracker was constructed on. - /// - /// Constructor is not thread safe - VTKM_CONT ScopedRuntimeDeviceTracker(vtkm::cont::DeviceAdapterId device, - RuntimeDeviceTrackerMode mode, - const vtkm::cont::RuntimeDeviceTracker& tracker); - - /// Construct a ScopedRuntimeDeviceTracker associated with the thread - /// associated with the provided tracker. - /// - /// Any modifications to the ScopedRuntimeDeviceTracker will effect what - /// ever thread the \c tracker is associated with, which might not be - /// the thread which ScopedRuntimeDeviceTracker was constructed on. - /// - /// Constructor is not thread safe - VTKM_CONT ScopedRuntimeDeviceTracker(const vtkm::cont::RuntimeDeviceTracker& tracker); - - /// Destructor is not thread safe - VTKM_CONT ~ScopedRuntimeDeviceTracker(); - -private: - std::unique_ptr SavedState; -}; - +///---------------------------------------------------------------------------- /// \brief Get the \c RuntimeDeviceTracker for the current thread. /// /// Many features in VTK-m will attempt to run algorithms on the "best @@ -236,6 +174,66 @@ private: VTKM_CONT_EXPORT VTKM_CONT vtkm::cont::RuntimeDeviceTracker& GetRuntimeDeviceTracker(); + +enum struct RuntimeDeviceTrackerMode +{ + Force, + Enable, + Disable +}; + +///---------------------------------------------------------------------------- +/// A class to create a scoped runtime device tracker object. This object captures the state +/// of the per-thread device tracker and will revert any changes applied +/// during its lifetime on destruction. +/// +struct VTKM_CONT_EXPORT ScopedRuntimeDeviceTracker : public vtkm::cont::RuntimeDeviceTracker +{ + /// Construct a ScopedRuntimeDeviceTracker associated with the thread, + /// associated with the provided tracker (defaults to current thread's tracker). + /// + /// Any modifications to the ScopedRuntimeDeviceTracker will effect what + /// ever thread the \c tracker is associated with, which might not be + /// the thread on which the ScopedRuntimeDeviceTracker was constructed. + /// + /// Constructors are not thread safe + /// @{ + /// + VTKM_CONT ScopedRuntimeDeviceTracker( + const vtkm::cont::RuntimeDeviceTracker& tracker = GetRuntimeDeviceTracker()); + + /// Use this constructor to modify the state of the device adapters associated with + /// the provided tracker. Use \p mode with \p device as follows: + /// + /// 'Force' (default) + /// - Force-Enable the provided single device adapter + /// - Force-Enable all device adapters when using vtkm::cont::DeviceAdaterTagAny + /// 'Enable' + /// - Enable the provided single device adapter if it was previously disabled + /// - Enable all device adapters that are currently disabled when using + /// vtkm::cont::DeviceAdaterTagAny + /// 'Disable' + /// - Disable the provided single device adapter + /// - Disable all device adapters when using vtkm::cont::DeviceAdaterTagAny + /// + VTKM_CONT ScopedRuntimeDeviceTracker( + vtkm::cont::DeviceAdapterId device, + RuntimeDeviceTrackerMode mode = RuntimeDeviceTrackerMode::Force, + const vtkm::cont::RuntimeDeviceTracker& tracker = GetRuntimeDeviceTracker()); + + /// Use this constructor to set the abort checker functor for the provided tracker. + /// + VTKM_CONT ScopedRuntimeDeviceTracker( + const std::function& abortChecker, + const vtkm::cont::RuntimeDeviceTracker& tracker = GetRuntimeDeviceTracker()); + + /// Destructor is not thread safe + VTKM_CONT ~ScopedRuntimeDeviceTracker(); + +private: + std::unique_ptr SavedState; +}; + } } // namespace vtkm::cont diff --git a/vtkm/cont/TryExecute.cxx b/vtkm/cont/TryExecute.cxx index 8b2f9b37f..59503c818 100644 --- a/vtkm/cont/TryExecute.cxx +++ b/vtkm/cont/TryExecute.cxx @@ -12,6 +12,7 @@ #include #include #include +#include #include namespace vtkm @@ -55,6 +56,13 @@ VTKM_CONT_EXPORT void HandleTryExecuteException(vtkm::cont::DeviceAdapterId devi VTKM_LOG_TRYEXECUTE_FAIL("ErrorBadValue (" << e.GetMessage() << ")", functorName, deviceId); throw; } + catch (vtkm::cont::ErrorUserAbort& e) + { + VTKM_LOG_S(vtkm::cont::LogLevel::Info, + e.GetMessage() << " Aborting: " << functorName << ", on device " + << deviceId.GetName()); + throw; + } catch (vtkm::cont::Error& e) { VTKM_LOG_TRYEXECUTE_FAIL(e.GetMessage(), functorName, deviceId); diff --git a/vtkm/cont/TryExecute.h b/vtkm/cont/TryExecute.h index 78942afa2..f62138149 100644 --- a/vtkm/cont/TryExecute.h +++ b/vtkm/cont/TryExecute.h @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -40,6 +41,11 @@ inline bool TryExecuteIfValid(std::true_type, { try { + if (tracker.CheckForAbortRequest()) + { + throw vtkm::cont::ErrorUserAbort{}; + } + return f(tag, std::forward(args)...); } catch (...) diff --git a/vtkm/cont/testing/CMakeLists.txt b/vtkm/cont/testing/CMakeLists.txt index e5711e5ea..ef4c4e566 100644 --- a/vtkm/cont/testing/CMakeLists.txt +++ b/vtkm/cont/testing/CMakeLists.txt @@ -122,6 +122,12 @@ if(TARGET vtkm_filter_field_conversion) ) endif() +if(TARGET vtkm_filter_contour) + list(APPEND unit_tests_device + UnitTestAbort.cxx + ) +endif() + vtkm_unit_tests(SOURCES ${unit_tests} DEVICE_SOURCES ${unit_tests_device}) #add distributed tests i.e.test to run with MPI diff --git a/vtkm/cont/testing/UnitTestAbort.cxx b/vtkm/cont/testing/UnitTestAbort.cxx new file mode 100644 index 000000000..a9cd50335 --- /dev/null +++ b/vtkm/cont/testing/UnitTestAbort.cxx @@ -0,0 +1,105 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +//============================================================================ + +#include +#include +#include +#include +#include + +#include + +namespace +{ + +// A function that checks for abort request. +// This function will be called by `TryExecute` befaure lauching a device task +// to check if abort has been requested. +// For this test case, we are using a simple logic of returning true for the +// `abortAt`th check. +// If this test is failing, one of the things to check would be to see if the +// `Contour` filter has changed such that it no longer has atleast `abortAt` +// task invocations. +bool ShouldAbort() +{ + static int abortCheckCounter = 0; + static constexpr int abortAt = 5; + if (++abortCheckCounter >= abortAt) + { + std::cout << "Abort check " << abortCheckCounter << ": true\n"; + return true; + } + + std::cout << "Abort check " << abortCheckCounter << ": false\n"; + return false; +} + +int TestAbort() +{ + vtkm::source::Wavelet wavelet; + wavelet.SetExtent(vtkm::Id3(-15), vtkm::Id3(16)); + auto input = wavelet.Execute(); + + auto range = input.GetField("RTData").GetRange().ReadPortal().Get(0); + std::vector isovals; + static constexpr int numDivs = 5; + for (int i = 1; i < numDivs - 1; ++i) + { + auto v = range.Min + + (static_cast(i) * + ((range.Max - range.Min) / static_cast(numDivs))); + isovals.push_back(v); + } + + vtkm::filter::contour::Contour contour; + contour.SetActiveField("RTData"); + contour.SetIsoValues(isovals); + + // First we will run the filter with the abort function set + std::cout << "Run #1 with the abort function set\n"; + try + { + vtkm::cont::ScopedRuntimeDeviceTracker tracker(ShouldAbort); + + auto result = contour.Execute(input); + + // execution shouldn't reach here + VTKM_TEST_FAIL("Error: filter execution was not aborted. Result: ", + result.GetNumberOfPoints(), + " points and ", + result.GetNumberOfCells(), + " triangles"); + } + catch (const vtkm::cont::ErrorUserAbort& e) + { + std::cout << "Execution was successfully aborted\n"; + } + + // Now run the filter without the abort function + std::cout << "Run #2 without the abort function set\n"; + try + { + auto result = contour.Execute(input); + std::cout << "Success: filter execution was not aborted. Result: " << result.GetNumberOfPoints() + << " points and " << result.GetNumberOfCells() << " triangles\n"; + } + catch (const vtkm::cont::ErrorUserAbort& e) + { + VTKM_TEST_FAIL("Execution was unexpectedly aborted"); + } + + return 0; +} +} // anon namespace + +int UnitTestAbort(int argc, char* argv[]) +{ + return vtkm::cont::testing::Testing::Run(TestAbort, argc, argv); +} diff --git a/vtkm/cont/vtkm.module b/vtkm/cont/vtkm.module index 33396e7ae..587729dce 100644 --- a/vtkm/cont/vtkm.module +++ b/vtkm/cont/vtkm.module @@ -9,6 +9,7 @@ DEPENDS OPTIONAL_DEPENDS vtkm_loguru TEST_OPTIONAL_DEPENDS + vtkm_filter_contour vtkm_filter_field_conversion TEST_DEPENDS vtkm_source