mirror of
https://gitlab.kitware.com/vtk/vtk-m
synced 2024-09-16 17:22:55 +00:00
VTK-m now supports case-insensitive construction of devices from strings.
Previously you had to exactly match the case of a device adapter's name to construct it, which was a source of lots of problems ( OpenMP versus OPENMP, CUDA or Cuda ). Now `vtkm::cont::make_DeviceAdapterId` and `vtkm::cont::RuntimeDeviceTracker` support case-insensitive device construction.
This commit is contained in:
parent
0ae31eb637
commit
ce95b8f788
14
docs/changelog/case-insensitive-device-from-string.md
Normal file
14
docs/changelog/case-insensitive-device-from-string.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# VTK-m `vtkm::cont::DeviceAdapterId` construction from string are now case-insensitive
|
||||||
|
|
||||||
|
You can now construct a `vtkm::cont::DeviceAdapterId` from a string no matter
|
||||||
|
the case of it. The following all will construct the same `vtkm::cont::DeviceAdapterId`.
|
||||||
|
|
||||||
|
```cpp
|
||||||
|
vtkm::cont::DeviceAdapterId id1 = vtkm::cont::make_DeviceAdapterId("cuda");
|
||||||
|
vtkm::cont::DeviceAdapterId id2 = vtkm::cont::make_DeviceAdapterId("CUDA");
|
||||||
|
vtkm::cont::DeviceAdapterId id3 = vtkm::cont::make_DeviceAdapterId("Cuda");
|
||||||
|
|
||||||
|
auto tracker = vtkm::cont::GetGlobalRuntimeDeviceTracker();
|
||||||
|
vtkm::cont::DeviceAdapterId id4 = tracker.GetDeviceAdapterId("cuda");
|
||||||
|
vtkm::cont::DeviceAdapterId id5 = tracker.GetDeviceAdapterId("CUDA");
|
||||||
|
vtkm::cont::DeviceAdapterId id6 = tracker.GetDeviceAdapterId("Cuda");
|
@ -38,6 +38,7 @@
|
|||||||
|
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
#include <cctype> //for tolower
|
||||||
#include <map>
|
#include <map>
|
||||||
#include <mutex>
|
#include <mutex>
|
||||||
#include <sstream>
|
#include <sstream>
|
||||||
@ -49,22 +50,33 @@ namespace
|
|||||||
struct VTKM_NEVER_EXPORT GetDeviceNameFunctor
|
struct VTKM_NEVER_EXPORT GetDeviceNameFunctor
|
||||||
{
|
{
|
||||||
vtkm::cont::DeviceAdapterNameType* Names;
|
vtkm::cont::DeviceAdapterNameType* Names;
|
||||||
|
vtkm::cont::DeviceAdapterNameType* LowerCaseNames;
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
GetDeviceNameFunctor(vtkm::cont::DeviceAdapterNameType* names)
|
GetDeviceNameFunctor(vtkm::cont::DeviceAdapterNameType* names,
|
||||||
|
vtkm::cont::DeviceAdapterNameType* lower)
|
||||||
: Names(names)
|
: Names(names)
|
||||||
|
, LowerCaseNames(lower)
|
||||||
{
|
{
|
||||||
std::fill_n(this->Names, VTKM_MAX_DEVICE_ADAPTER_ID, "InvalidDeviceId");
|
std::fill_n(this->Names, VTKM_MAX_DEVICE_ADAPTER_ID, "InvalidDeviceId");
|
||||||
|
std::fill_n(this->LowerCaseNames, VTKM_MAX_DEVICE_ADAPTER_ID, "invaliddeviceid");
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Device>
|
template <typename Device>
|
||||||
VTKM_CONT void operator()(Device device)
|
VTKM_CONT void operator()(Device device)
|
||||||
{
|
{
|
||||||
|
auto lowerCaseFunc = [](char c) {
|
||||||
|
return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
|
||||||
|
};
|
||||||
|
|
||||||
auto id = device.GetValue();
|
auto id = device.GetValue();
|
||||||
|
|
||||||
if (id > 0 && id < VTKM_MAX_DEVICE_ADAPTER_ID)
|
if (id > 0 && id < VTKM_MAX_DEVICE_ADAPTER_ID)
|
||||||
{
|
{
|
||||||
this->Names[id] = vtkm::cont::DeviceAdapterTraits<Device>::GetName();
|
auto name = vtkm::cont::DeviceAdapterTraits<Device>::GetName();
|
||||||
|
this->Names[id] = name;
|
||||||
|
std::transform(name.begin(), name.end(), name.begin(), lowerCaseFunc);
|
||||||
|
this->LowerCaseNames[id] = name;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -87,6 +99,7 @@ struct RuntimeDeviceTrackerInternals
|
|||||||
{
|
{
|
||||||
bool RuntimeValid[VTKM_MAX_DEVICE_ADAPTER_ID];
|
bool RuntimeValid[VTKM_MAX_DEVICE_ADAPTER_ID];
|
||||||
DeviceAdapterNameType DeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID];
|
DeviceAdapterNameType DeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID];
|
||||||
|
DeviceAdapterNameType LowerCaseDeviceNames[VTKM_MAX_DEVICE_ADAPTER_ID];
|
||||||
};
|
};
|
||||||
|
|
||||||
struct RuntimeDeviceTrackerFunctor
|
struct RuntimeDeviceTrackerFunctor
|
||||||
@ -107,7 +120,7 @@ VTKM_CONT
|
|||||||
RuntimeDeviceTracker::RuntimeDeviceTracker()
|
RuntimeDeviceTracker::RuntimeDeviceTracker()
|
||||||
: Internals(std::make_shared<detail::RuntimeDeviceTrackerInternals>())
|
: Internals(std::make_shared<detail::RuntimeDeviceTrackerInternals>())
|
||||||
{
|
{
|
||||||
GetDeviceNameFunctor functor(this->Internals->DeviceNames);
|
GetDeviceNameFunctor functor(this->Internals->DeviceNames, this->Internals->LowerCaseDeviceNames);
|
||||||
vtkm::ListForEach(functor, VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG());
|
vtkm::ListForEach(functor, VTKM_DEFAULT_DEVICE_ADAPTER_LIST_TAG());
|
||||||
|
|
||||||
this->Reset();
|
this->Reset();
|
||||||
@ -197,6 +210,9 @@ RuntimeDeviceTracker::RuntimeDeviceTracker(
|
|||||||
{
|
{
|
||||||
std::copy_n(internals->RuntimeValid, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->RuntimeValid);
|
std::copy_n(internals->RuntimeValid, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->RuntimeValid);
|
||||||
std::copy_n(internals->DeviceNames, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->DeviceNames);
|
std::copy_n(internals->DeviceNames, VTKM_MAX_DEVICE_ADAPTER_ID, this->Internals->DeviceNames);
|
||||||
|
std::copy_n(internals->LowerCaseDeviceNames,
|
||||||
|
VTKM_MAX_DEVICE_ADAPTER_ID,
|
||||||
|
this->Internals->LowerCaseDeviceNames);
|
||||||
}
|
}
|
||||||
|
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
@ -265,22 +281,30 @@ DeviceAdapterNameType RuntimeDeviceTracker::GetDeviceName(DeviceAdapterId device
|
|||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
DeviceAdapterId RuntimeDeviceTracker::GetDeviceAdapterId(DeviceAdapterNameType name) const
|
DeviceAdapterId RuntimeDeviceTracker::GetDeviceAdapterId(DeviceAdapterNameType name) const
|
||||||
{
|
{
|
||||||
if (name == "Any")
|
// The GetDeviceAdapterId call is case-insensitive so transform the name to be lower case
|
||||||
|
// as that is how we cache the case-insensitive version.
|
||||||
|
auto lowerCaseFunc = [](char c) {
|
||||||
|
return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
|
||||||
|
};
|
||||||
|
std::transform(name.begin(), name.end(), name.begin(), lowerCaseFunc);
|
||||||
|
|
||||||
|
//lower-case the name here
|
||||||
|
if (name == "any")
|
||||||
{
|
{
|
||||||
return vtkm::cont::DeviceAdapterTagAny{};
|
return vtkm::cont::DeviceAdapterTagAny{};
|
||||||
}
|
}
|
||||||
else if (name == "Error")
|
else if (name == "error")
|
||||||
{
|
{
|
||||||
return vtkm::cont::DeviceAdapterTagError{};
|
return vtkm::cont::DeviceAdapterTagError{};
|
||||||
}
|
}
|
||||||
else if (name == "Undefined")
|
else if (name == "undefined")
|
||||||
{
|
{
|
||||||
return vtkm::cont::DeviceAdapterTagUndefined{};
|
return vtkm::cont::DeviceAdapterTagUndefined{};
|
||||||
}
|
}
|
||||||
|
|
||||||
for (vtkm::Int8 id = 0; id < VTKM_MAX_DEVICE_ADAPTER_ID; ++id)
|
for (vtkm::Int8 id = 0; id < VTKM_MAX_DEVICE_ADAPTER_ID; ++id)
|
||||||
{
|
{
|
||||||
if (name == this->Internals->DeviceNames[id])
|
if (name == this->Internals->LowerCaseDeviceNames[id])
|
||||||
{
|
{
|
||||||
return vtkm::cont::make_DeviceAdapterId(id);
|
return vtkm::cont::make_DeviceAdapterId(id);
|
||||||
}
|
}
|
||||||
|
@ -209,7 +209,8 @@ public:
|
|||||||
DeviceAdapterNameType GetDeviceName(DeviceAdapterId id) const;
|
DeviceAdapterNameType GetDeviceName(DeviceAdapterId id) const;
|
||||||
|
|
||||||
/// Returns the id corresponding to the device adapter name. If @a name is
|
/// Returns the id corresponding to the device adapter name. If @a name is
|
||||||
/// not recognized, DeviceAdapterTagUndefined is returned.
|
/// not recognized, DeviceAdapterTagUndefined is returned. Queries for a
|
||||||
|
/// name are all case-insensitive.
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
VTKM_CONT
|
VTKM_CONT
|
||||||
DeviceAdapterId GetDeviceAdapterId(DeviceAdapterNameType name) const;
|
DeviceAdapterId GetDeviceAdapterId(DeviceAdapterNameType name) const;
|
||||||
|
@ -64,7 +64,6 @@ struct DeviceAdapterId
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
friend DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id);
|
friend DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id);
|
||||||
friend DeviceAdapterId make_DeviceAdapterIdFromName(const std::string& name);
|
|
||||||
|
|
||||||
constexpr explicit DeviceAdapterId(vtkm::Int8 id)
|
constexpr explicit DeviceAdapterId(vtkm::Int8 id)
|
||||||
: Value(id)
|
: Value(id)
|
||||||
@ -75,9 +74,19 @@ private:
|
|||||||
vtkm::Int8 Value;
|
vtkm::Int8 Value;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Construct a device adapter id from a runtime string
|
||||||
|
/// The string is case-insensitive. So CUDA will be selected with 'cuda', 'Cuda', or 'CUDA'.
|
||||||
VTKM_CONT_EXPORT
|
VTKM_CONT_EXPORT
|
||||||
DeviceAdapterId make_DeviceAdapterId(const DeviceAdapterNameType& name);
|
DeviceAdapterId make_DeviceAdapterId(const DeviceAdapterNameType& name);
|
||||||
|
|
||||||
|
/// Construct a device adapter id a vtkm::Int8.
|
||||||
|
/// The mapping of integer value to devices are:
|
||||||
|
///
|
||||||
|
/// DeviceAdapterTagSerial == 1
|
||||||
|
/// DeviceAdapterTagCuda == 2
|
||||||
|
/// DeviceAdapterTagTBB == 3
|
||||||
|
/// DeviceAdapterTagOpenMP == 4
|
||||||
|
///
|
||||||
inline DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id)
|
inline DeviceAdapterId make_DeviceAdapterId(vtkm::Int8 id)
|
||||||
{
|
{
|
||||||
return DeviceAdapterId(id);
|
return DeviceAdapterId(id);
|
||||||
|
@ -28,6 +28,8 @@
|
|||||||
|
|
||||||
#include <vtkm/cont/testing/Testing.h>
|
#include <vtkm/cont/testing/Testing.h>
|
||||||
|
|
||||||
|
#include <cctype> //for tolower
|
||||||
|
|
||||||
namespace
|
namespace
|
||||||
{
|
{
|
||||||
|
|
||||||
@ -53,12 +55,37 @@ void TestName(const std::string& name, Tag tag, vtkm::cont::DeviceAdapterId id)
|
|||||||
<< "\t" << tracker.GetDeviceName(id) << "\n"
|
<< "\t" << tracker.GetDeviceName(id) << "\n"
|
||||||
<< "\t" << tracker.GetDeviceName(tag) << "\n";
|
<< "\t" << tracker.GetDeviceName(tag) << "\n";
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
VTKM_TEST_ASSERT(id.GetName() == name, "Id::GetName() failed.");
|
VTKM_TEST_ASSERT(id.GetName() == name, "Id::GetName() failed.");
|
||||||
VTKM_TEST_ASSERT(tag.GetName() == name, "Tag::GetName() failed.");
|
VTKM_TEST_ASSERT(tag.GetName() == name, "Tag::GetName() failed.");
|
||||||
|
VTKM_TEST_ASSERT(vtkm::cont::make_DeviceAdapterId(id.GetValue()) == id,
|
||||||
|
"make_DeviceAdapterId(int8) failed");
|
||||||
VTKM_TEST_ASSERT(tracker.GetDeviceName(id) == name, "RTDeviceTracker::GetDeviceName(Id) failed.");
|
VTKM_TEST_ASSERT(tracker.GetDeviceName(id) == name, "RTDeviceTracker::GetDeviceName(Id) failed.");
|
||||||
VTKM_TEST_ASSERT(tracker.GetDeviceName(tag) == name,
|
VTKM_TEST_ASSERT(tracker.GetDeviceName(tag) == name,
|
||||||
"RTDeviceTracker::GetDeviceName(Tag) failed.");
|
"RTDeviceTracker::GetDeviceName(Tag) failed.");
|
||||||
|
|
||||||
|
//check going from name to device id
|
||||||
|
auto lowerCaseFunc = [](char c) {
|
||||||
|
return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
|
||||||
|
};
|
||||||
|
|
||||||
|
auto upperCaseFunc = [](char c) {
|
||||||
|
return static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
|
||||||
|
};
|
||||||
|
|
||||||
|
if (id.IsValueValid())
|
||||||
|
{ //only test make_DeviceAdapterId with valid device ids
|
||||||
|
VTKM_TEST_ASSERT(
|
||||||
|
vtkm::cont::make_DeviceAdapterId(name) == id, "make_DeviceAdapterId(", name, ") failed");
|
||||||
|
|
||||||
|
std::string casedName = name;
|
||||||
|
std::transform(casedName.begin(), casedName.end(), casedName.begin(), lowerCaseFunc);
|
||||||
|
VTKM_TEST_ASSERT(
|
||||||
|
vtkm::cont::make_DeviceAdapterId(casedName) == id, "make_DeviceAdapterId(", name, ") failed");
|
||||||
|
|
||||||
|
std::transform(casedName.begin(), casedName.end(), casedName.begin(), upperCaseFunc);
|
||||||
|
VTKM_TEST_ASSERT(
|
||||||
|
vtkm::cont::make_DeviceAdapterId(casedName) == id, "make_DeviceAdapterId(", name, ") failed");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestNames()
|
void TestNames()
|
||||||
|
Loading…
Reference in New Issue
Block a user