Merge topic 'universal-dot-product-filter'
0c47645ba Remove DotProduct from worklet directory d76929a78 Allow dot product to work with any field type Acked-by: Kitware Robot <kwrobot@kitware.com> Acked-by: Li-Ta Lo <ollie@lanl.gov> Acked-by: Kenneth Moreland <morelandkd@ornl.gov> Merge-request: !2682
This commit is contained in:
commit
ff43a2efa5
@ -8,33 +8,59 @@
|
|||||||
// PURPOSE. See the above copyright notice for more information.
|
// PURPOSE. See the above copyright notice for more information.
|
||||||
//============================================================================
|
//============================================================================
|
||||||
|
|
||||||
|
#include <vtkm/cont/ErrorFilterExecution.h>
|
||||||
#include <vtkm/filter/vector_calculus/DotProduct.h>
|
#include <vtkm/filter/vector_calculus/DotProduct.h>
|
||||||
#include <vtkm/worklet/WorkletMapField.h>
|
#include <vtkm/worklet/WorkletMapField.h>
|
||||||
|
|
||||||
namespace // anonymous namespace making worklet::DotProduct internal to this .cxx
|
namespace // anonymous namespace making worklet::DotProduct internal to this .cxx
|
||||||
{
|
{
|
||||||
namespace worklet
|
|
||||||
|
struct DotProductWorklet : vtkm::worklet::WorkletMapField
|
||||||
{
|
{
|
||||||
class DotProduct : public vtkm::worklet::WorkletMapField
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
using ControlSignature = void(FieldIn, FieldIn, FieldOut);
|
using ControlSignature = void(FieldIn, FieldIn, FieldOut);
|
||||||
|
|
||||||
template <typename T, vtkm::IdComponent Size>
|
template <typename T1, typename T2, typename T3>
|
||||||
VTKM_EXEC void operator()(const vtkm::Vec<T, Size>& v1,
|
VTKM_EXEC void operator()(const T1& v1, const T2& v2, T3& outValue) const
|
||||||
const vtkm::Vec<T, Size>& v2,
|
|
||||||
T& outValue) const
|
|
||||||
{
|
{
|
||||||
outValue = static_cast<T>(vtkm::Dot(v1, v2));
|
VTKM_ASSERT(v1.GetNumberOfComponents() == v2.GetNumberOfComponents());
|
||||||
}
|
outValue = v1[0] * v2[0];
|
||||||
|
for (vtkm::IdComponent i = 1; i < v1.GetNumberOfComponents(); ++i)
|
||||||
template <typename T>
|
{
|
||||||
VTKM_EXEC void operator()(T s1, T s2, T& outValue) const
|
outValue += v1[i] * v2[i];
|
||||||
{
|
}
|
||||||
outValue = static_cast<T>(s1 * s2);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace worklet
|
|
||||||
|
template <typename PrimaryArrayType>
|
||||||
|
vtkm::cont::UnknownArrayHandle DoDotProduct(const PrimaryArrayType& primaryArray,
|
||||||
|
const vtkm::cont::Field& secondaryField)
|
||||||
|
{
|
||||||
|
using T = typename PrimaryArrayType::ValueType::ComponentType;
|
||||||
|
|
||||||
|
vtkm::cont::Invoker invoke;
|
||||||
|
vtkm::cont::ArrayHandle<T> outputArray;
|
||||||
|
|
||||||
|
if (secondaryField.GetData().IsBaseComponentType<T>())
|
||||||
|
{
|
||||||
|
invoke(DotProductWorklet{},
|
||||||
|
primaryArray,
|
||||||
|
secondaryField.GetData().ExtractArrayFromComponents<T>(),
|
||||||
|
outputArray);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// Data types of primary and secondary array do not match. Rather than try to replicate every
|
||||||
|
// possibility, get the secondary array as a FloatDefault.
|
||||||
|
vtkm::cont::UnknownArrayHandle castSecondaryArray = secondaryField.GetDataAsDefaultFloat();
|
||||||
|
invoke(DotProductWorklet{},
|
||||||
|
primaryArray,
|
||||||
|
castSecondaryArray.ExtractArrayFromComponents<vtkm::FloatDefault>(),
|
||||||
|
outputArray);
|
||||||
|
}
|
||||||
|
|
||||||
|
return outputArray;
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
namespace vtkm
|
namespace vtkm
|
||||||
@ -51,37 +77,39 @@ VTKM_CONT DotProduct::DotProduct()
|
|||||||
|
|
||||||
VTKM_CONT vtkm::cont::DataSet DotProduct::DoExecute(const vtkm::cont::DataSet& inDataSet)
|
VTKM_CONT vtkm::cont::DataSet DotProduct::DoExecute(const vtkm::cont::DataSet& inDataSet)
|
||||||
{
|
{
|
||||||
const auto& primaryArray = this->GetFieldFromDataSet(inDataSet).GetData();
|
vtkm::cont::Field primaryField = this->GetFieldFromDataSet(0, inDataSet);
|
||||||
|
vtkm::cont::UnknownArrayHandle primaryArray = primaryField.GetData();
|
||||||
|
|
||||||
|
vtkm::cont::Field secondaryField = this->GetFieldFromDataSet(1, inDataSet);
|
||||||
|
|
||||||
|
if (primaryArray.GetNumberOfComponentsFlat() !=
|
||||||
|
secondaryField.GetData().GetNumberOfComponentsFlat())
|
||||||
|
{
|
||||||
|
throw vtkm::cont::ErrorFilterExecution(
|
||||||
|
"Primary and secondary arrays of DotProduct filter have different number of components.");
|
||||||
|
}
|
||||||
|
|
||||||
vtkm::cont::UnknownArrayHandle outArray;
|
vtkm::cont::UnknownArrayHandle outArray;
|
||||||
|
|
||||||
// We are using a C++14 auto lambda here. The advantage over a Functor is obvious, we don't
|
if (primaryArray.IsBaseComponentType<vtkm::Float32>())
|
||||||
// need to explicitly pass filter, input/output DataSets etc. thus reduce the impact to
|
{
|
||||||
// the legacy code. The lambda can also access the private part of the filter thus reducing
|
outArray =
|
||||||
// filter's public interface profile. CastAndCall tries to cast primaryArray of unknown value
|
DoDotProduct(primaryArray.ExtractArrayFromComponents<vtkm::Float32>(), secondaryField);
|
||||||
// type and storage to a concrete ArrayHandle<T, S> with T from the `TypeList` and S from
|
}
|
||||||
// `StorageList`. It then passes the concrete array to the lambda as the first argument.
|
else if (primaryArray.IsBaseComponentType<vtkm::Float64>())
|
||||||
// We can later recover the concrete ValueType, T, from the concrete array.
|
{
|
||||||
auto ResolveType = [&, this](const auto& concrete) {
|
outArray =
|
||||||
// use std::decay to remove const ref from the decltype of concrete.
|
DoDotProduct(primaryArray.ExtractArrayFromComponents<vtkm::Float64>(), secondaryField);
|
||||||
using T = typename std::decay_t<decltype(concrete)>::ValueType;
|
}
|
||||||
const auto& secondaryField = this->GetFieldFromDataSet(1, inDataSet);
|
else
|
||||||
vtkm::cont::UnknownArrayHandle secondary = vtkm::cont::ArrayHandle<T>{};
|
{
|
||||||
secondary.CopyShallowIfPossible(secondaryField.GetData());
|
primaryArray = primaryField.GetDataAsDefaultFloat();
|
||||||
|
outArray =
|
||||||
|
DoDotProduct(primaryArray.ExtractArrayFromComponents<vtkm::FloatDefault>(), secondaryField);
|
||||||
|
}
|
||||||
|
|
||||||
vtkm::cont::ArrayHandle<typename vtkm::VecTraits<T>::ComponentType> result;
|
vtkm::cont::DataSet outDataSet;
|
||||||
this->Invoke(::worklet::DotProduct{},
|
outDataSet.CopyStructure(inDataSet);
|
||||||
concrete,
|
|
||||||
secondary.template AsArrayHandle<vtkm::cont::ArrayHandle<T>>(),
|
|
||||||
result);
|
|
||||||
outArray = result;
|
|
||||||
};
|
|
||||||
|
|
||||||
primaryArray
|
|
||||||
.CastAndCallForTypesWithFloatFallback<VTKM_DEFAULT_TYPE_LIST, VTKM_DEFAULT_STORAGE_LIST>(
|
|
||||||
ResolveType);
|
|
||||||
|
|
||||||
vtkm::cont::DataSet outDataSet = inDataSet; // copy
|
|
||||||
outDataSet.AddField({ this->GetOutputFieldName(),
|
outDataSet.AddField({ this->GetOutputFieldName(),
|
||||||
this->GetFieldFromDataSet(inDataSet).GetAssociation(),
|
this->GetFieldFromDataSet(inDataSet).GetAssociation(),
|
||||||
outArray });
|
outArray });
|
||||||
|
@ -26,7 +26,6 @@ set(headers
|
|||||||
DispatcherCellNeighborhood.h
|
DispatcherCellNeighborhood.h
|
||||||
DispatcherPointNeighborhood.h
|
DispatcherPointNeighborhood.h
|
||||||
DispatcherReduceByKey.h
|
DispatcherReduceByKey.h
|
||||||
DotProduct.h
|
|
||||||
FieldStatistics.h
|
FieldStatistics.h
|
||||||
Gradient.h
|
Gradient.h
|
||||||
ImageDifference.h
|
ImageDifference.h
|
||||||
|
@ -1,45 +0,0 @@
|
|||||||
//============================================================================
|
|
||||||
// Copyright (c) Kitware, Inc.
|
|
||||||
// All rights reserved.
|
|
||||||
// See LICENSE.txt for details.
|
|
||||||
//
|
|
||||||
// This software is distributed WITHOUT ANY WARRANTY; without even
|
|
||||||
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
|
||||||
// PURPOSE. See the above copyright notice for more information.
|
|
||||||
//============================================================================
|
|
||||||
#ifndef vtk_m_worklet_DotProduct_h
|
|
||||||
#define vtk_m_worklet_DotProduct_h
|
|
||||||
|
|
||||||
#include <vtkm/worklet/WorkletMapField.h>
|
|
||||||
|
|
||||||
#include <vtkm/Math.h>
|
|
||||||
#include <vtkm/VectorAnalysis.h>
|
|
||||||
|
|
||||||
namespace vtkm
|
|
||||||
{
|
|
||||||
namespace worklet
|
|
||||||
{
|
|
||||||
|
|
||||||
class DotProduct : public vtkm::worklet::WorkletMapField
|
|
||||||
{
|
|
||||||
public:
|
|
||||||
using ControlSignature = void(FieldIn, FieldIn, FieldOut);
|
|
||||||
|
|
||||||
template <typename T, vtkm::IdComponent Size>
|
|
||||||
VTKM_EXEC void operator()(const vtkm::Vec<T, Size>& v1,
|
|
||||||
const vtkm::Vec<T, Size>& v2,
|
|
||||||
T& outValue) const
|
|
||||||
{
|
|
||||||
outValue = static_cast<T>(vtkm::Dot(v1, v2));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
VTKM_EXEC void operator()(T s1, T s2, T& outValue) const
|
|
||||||
{
|
|
||||||
outValue = static_cast<T>(s1 * s2);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
} // namespace vtkm::worklet
|
|
||||||
|
|
||||||
#endif // vtk_m_worklet_Normalize_h
|
|
@ -31,7 +31,6 @@ set(unit_tests
|
|||||||
UnitTestCosmoTools.cxx
|
UnitTestCosmoTools.cxx
|
||||||
UnitTestCrossProduct.cxx
|
UnitTestCrossProduct.cxx
|
||||||
UnitTestDescriptiveStatistics.cxx
|
UnitTestDescriptiveStatistics.cxx
|
||||||
UnitTestDotProduct.cxx
|
|
||||||
UnitTestFieldStatistics.cxx
|
UnitTestFieldStatistics.cxx
|
||||||
UnitTestGraphConnectivity.cxx
|
UnitTestGraphConnectivity.cxx
|
||||||
UnitTestInnerJoin.cxx
|
UnitTestInnerJoin.cxx
|
||||||
|
@ -1,105 +0,0 @@
|
|||||||
//============================================================================
|
|
||||||
// Copyright (c) Kitware, Inc.
|
|
||||||
// All rights reserved.
|
|
||||||
// See LICENSE.txt for details.
|
|
||||||
//
|
|
||||||
// This software is distributed WITHOUT ANY WARRANTY; without even
|
|
||||||
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
|
||||||
// PURPOSE. See the above copyright notice for more information.
|
|
||||||
//============================================================================
|
|
||||||
|
|
||||||
#include <vtkm/worklet/DispatcherMapField.h>
|
|
||||||
#include <vtkm/worklet/DotProduct.h>
|
|
||||||
|
|
||||||
#include <vtkm/cont/testing/Testing.h>
|
|
||||||
|
|
||||||
namespace
|
|
||||||
{
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
T normalizedVector(T v)
|
|
||||||
{
|
|
||||||
T vN = vtkm::Normal(v);
|
|
||||||
return vN;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void createVectors(std::vector<vtkm::Vec<T, 3>>& vecs1,
|
|
||||||
std::vector<vtkm::Vec<T, 3>>& vecs2,
|
|
||||||
std::vector<T>& result)
|
|
||||||
{
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
|
|
||||||
result.push_back(1);
|
|
||||||
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(-1), T(0), T(0))));
|
|
||||||
result.push_back(-1);
|
|
||||||
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(0), T(1), T(0))));
|
|
||||||
result.push_back(0);
|
|
||||||
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(0), T(-1), T(0))));
|
|
||||||
result.push_back(0);
|
|
||||||
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
|
|
||||||
result.push_back(T(1.0 / vtkm::Sqrt(2.0)));
|
|
||||||
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(0), T(0))));
|
|
||||||
result.push_back(T(1.0 / vtkm::Sqrt(2.0)));
|
|
||||||
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(-1), T(0), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
|
|
||||||
result.push_back(-T(1.0 / vtkm::Sqrt(2.0)));
|
|
||||||
|
|
||||||
vecs1.push_back(normalizedVector(vtkm::make_Vec(T(0), T(1), T(0))));
|
|
||||||
vecs2.push_back(normalizedVector(vtkm::make_Vec(T(1), T(1), T(0))));
|
|
||||||
result.push_back(T(1.0 / vtkm::Sqrt(2.0)));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void TestDotProduct()
|
|
||||||
{
|
|
||||||
std::vector<vtkm::Vec<T, 3>> inputVecs1, inputVecs2;
|
|
||||||
std::vector<T> answer;
|
|
||||||
createVectors(inputVecs1, inputVecs2, answer);
|
|
||||||
|
|
||||||
vtkm::cont::ArrayHandle<vtkm::Vec<T, 3>> inputArray1, inputArray2;
|
|
||||||
vtkm::cont::ArrayHandle<T> outputArray;
|
|
||||||
inputArray1 = vtkm::cont::make_ArrayHandle(inputVecs1, vtkm::CopyFlag::Off);
|
|
||||||
inputArray2 = vtkm::cont::make_ArrayHandle(inputVecs2, vtkm::CopyFlag::Off);
|
|
||||||
|
|
||||||
vtkm::worklet::DotProduct dotProductWorklet;
|
|
||||||
vtkm::worklet::DispatcherMapField<vtkm::worklet::DotProduct> dispatcherDotProduct(
|
|
||||||
dotProductWorklet);
|
|
||||||
dispatcherDotProduct.Invoke(inputArray1, inputArray2, outputArray);
|
|
||||||
|
|
||||||
VTKM_TEST_ASSERT(outputArray.GetNumberOfValues() == inputArray1.GetNumberOfValues(),
|
|
||||||
"Wrong number of results for DotProduct worklet");
|
|
||||||
|
|
||||||
for (vtkm::Id i = 0; i < inputArray1.GetNumberOfValues(); i++)
|
|
||||||
{
|
|
||||||
vtkm::Vec<T, 3> v1 = inputArray1.ReadPortal().Get(i);
|
|
||||||
vtkm::Vec<T, 3> v2 = inputArray2.ReadPortal().Get(i);
|
|
||||||
T ans = answer[static_cast<std::size_t>(i)];
|
|
||||||
|
|
||||||
VTKM_TEST_ASSERT(test_equal(ans, vtkm::Dot(v1, v2)), "Wrong result for dot product");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void TestDotProductWorklets()
|
|
||||||
{
|
|
||||||
std::cout << "Testing DotProduct Worklet" << std::endl;
|
|
||||||
TestDotProduct<vtkm::Float32>();
|
|
||||||
// TestDotProduct<vtkm::Float64>();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
int UnitTestDotProduct(int argc, char* argv[])
|
|
||||||
{
|
|
||||||
return vtkm::cont::testing::Testing::Run(TestDotProductWorklets, argc, argv);
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user