vtk-m/vtkm/worklet/NDimsHistMarginalization.h

201 lines
8.8 KiB
C++

//============================================================================
// Copyright (c) Kitware, Inc.
// All rights reserved.
// See LICENSE.txt for details.
//
// This software is distributed WITHOUT ANY WARRANTY; without even
// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
// PURPOSE. See the above copyright notice for more information.
//============================================================================
#ifndef vtk_m_worklet_NDimsHistMarginalization_h
#define vtk_m_worklet_NDimsHistMarginalization_h
#include <vtkm/Math.h>
#include <vtkm/cont/Algorithm.h>
#include <vtkm/cont/ArrayCopy.h>
#include <vtkm/cont/ArrayHandle.h>
#include <vtkm/cont/ArrayHandleConstant.h>
#include <vtkm/cont/ArrayHandleCounting.h>
#include <vtkm/cont/DataSet.h>
#include <vtkm/cont/ErrorBadValue.h>
#include <vtkm/cont/Field.h>
#include <vtkm/worklet/DispatcherMapField.h>
#include <vtkm/worklet/WorkletMapField.h>
#include <vtkm/worklet/histogram/ComputeNDHistogram.h>
#include <vtkm/worklet/histogram/MarginalizeNDHistogram.h>

#include <vector>
namespace vtkm
{
namespace worklet
{
class NDimsHistMarginalization
{
public:
// Execute the histogram (conditional) marginalization,
// given the multi-variable histogram(binId, freqIn)
// , marginalVariable and marginal condition
// Input arguments:
// binId, freqsIn: input ND-histogram in the fashion of sparse representation
// (definition of binId and frqIn please refer to NDimsHistogram.h),
// (binId.size() is the number of variables)
// numberOfBins: number of bins of each variable (length of numberOfBins must be the same as binId.size() )
// marginalVariables: length is the same as number of variables.
// 1 indicates marginal variable, otherwise 0.
// conditionFunc: The Condition function for non-marginal variable.
// This func takes two arguments (vtkm::Id var, vtkm::Id binId) and return bool
// var is index of variable and binId is bin index in the variable var
// return true indicates considering this bin into final marginal histogram
// more details can refer to example in UnitTestNDimsHistMarginalization.cxx
// marginalBinId, marginalFreqs: return marginalized histogram in the fashion of sparse representation
// the definition is the same as (binId and freqsIn)
template <typename BinaryCompare>
void Run(const std::vector<vtkm::cont::ArrayHandle<vtkm::Id>>& binId,
vtkm::cont::ArrayHandle<vtkm::Id>& freqsIn,
vtkm::cont::ArrayHandle<vtkm::Id>& numberOfBins,
vtkm::cont::ArrayHandle<bool>& marginalVariables,
BinaryCompare conditionFunc,
std::vector<vtkm::cont::ArrayHandle<vtkm::Id>>& marginalBinId,
vtkm::cont::ArrayHandle<vtkm::Id>& marginalFreqs)
{
//total variables
vtkm::Id numOfVariable = static_cast<vtkm::Id>(binId.size());
const vtkm::Id numberOfValues = freqsIn.GetNumberOfValues();
vtkm::cont::ArrayHandleConstant<vtkm::Id> constant0Array(0, numberOfValues);
vtkm::cont::ArrayHandle<vtkm::Id> bin1DIndex;
vtkm::cont::ArrayCopy(constant0Array, bin1DIndex);
vtkm::cont::ArrayHandle<vtkm::Id> freqs;
vtkm::cont::ArrayCopy(freqsIn, freqs);
vtkm::Id numMarginalVariables = 0; //count num of marginal variables
const auto marginalPortal = marginalVariables.ReadPortal();
const auto numBinsPortal = numberOfBins.ReadPortal();
for (vtkm::Id i = 0; i < numOfVariable; i++)
{
if (marginalPortal.Get(i) == true)
{
// Worklet to calculate 1D index for marginal variables
numMarginalVariables++;
const vtkm::Id nFieldBins = numBinsPortal.Get(i);
vtkm::worklet::histogram::To1DIndex binWorklet(nFieldBins);
vtkm::worklet::DispatcherMapField<vtkm::worklet::histogram::To1DIndex> to1DIndexDispatcher(
binWorklet);
size_t vecIndex = static_cast<size_t>(i);
to1DIndexDispatcher.Invoke(binId[vecIndex], bin1DIndex, bin1DIndex);
}
else
{ //non-marginal variable
// Worklet to set the frequency of entities which does not meet the condition
// to 0 on non-marginal variables
vtkm::worklet::histogram::ConditionalFreq<BinaryCompare> conditionalFreqWorklet{
conditionFunc
};
conditionalFreqWorklet.setVar(i);
vtkm::worklet::DispatcherMapField<vtkm::worklet::histogram::ConditionalFreq<BinaryCompare>>
cfDispatcher(conditionalFreqWorklet);
size_t vecIndex = static_cast<size_t>(i);
cfDispatcher.Invoke(binId[vecIndex], freqs, freqs);
}
}
// Sort the freq array for counting by key(1DIndex)
vtkm::cont::Algorithm::SortByKey(bin1DIndex, freqs);
// Add frequency within same 1d index bin (this get a nonSparse representation)
vtkm::cont::ArrayHandle<vtkm::Id> nonSparseMarginalFreqs;
vtkm::cont::Algorithm::ReduceByKey(
bin1DIndex, freqs, bin1DIndex, nonSparseMarginalFreqs, vtkm::Add());
// Convert to sparse representation(remove all zero freqncy entities)
vtkm::cont::ArrayHandle<vtkm::Id> sparseMarginal1DBinId;
vtkm::cont::Algorithm::CopyIf(bin1DIndex, nonSparseMarginalFreqs, sparseMarginal1DBinId);
vtkm::cont::Algorithm::CopyIf(nonSparseMarginalFreqs, nonSparseMarginalFreqs, marginalFreqs);
//convert back to multi variate binId
marginalBinId.resize(static_cast<size_t>(numMarginalVariables));
vtkm::Id marginalVarIdx = numMarginalVariables - 1;
for (vtkm::Id i = numOfVariable - 1; i >= 0; i--)
{
if (marginalPortal.Get(i) == true)
{
const vtkm::Id nFieldBins = numBinsPortal.Get(i);
vtkm::worklet::histogram::ConvertHistBinToND binWorklet(nFieldBins);
vtkm::worklet::DispatcherMapField<vtkm::worklet::histogram::ConvertHistBinToND>
convertHistBinToNDDispatcher(binWorklet);
size_t vecIndex = static_cast<size_t>(marginalVarIdx);
convertHistBinToNDDispatcher.Invoke(
sparseMarginal1DBinId, sparseMarginal1DBinId, marginalBinId[vecIndex]);
marginalVarIdx--;
}
}
} //Run()
// Execute the histogram marginalization WITHOUT CONDITION,
// Please refer to the other Run() functions for the definition of input arguments.
void Run(const std::vector<vtkm::cont::ArrayHandle<vtkm::Id>>& binId,
vtkm::cont::ArrayHandle<vtkm::Id>& freqsIn,
vtkm::cont::ArrayHandle<vtkm::Id>& numberOfBins,
vtkm::cont::ArrayHandle<bool>& marginalVariables,
std::vector<vtkm::cont::ArrayHandle<vtkm::Id>>& marginalBinId,
vtkm::cont::ArrayHandle<vtkm::Id>& marginalFreqs)
{
//total variables
vtkm::Id numOfVariable = static_cast<vtkm::Id>(binId.size());
const vtkm::Id numberOfValues = freqsIn.GetNumberOfValues();
vtkm::cont::ArrayHandleConstant<vtkm::Id> constant0Array(0, numberOfValues);
vtkm::cont::ArrayHandle<vtkm::Id> bin1DIndex;
vtkm::cont::ArrayCopy(constant0Array, bin1DIndex);
vtkm::cont::ArrayHandle<vtkm::Id> freqs;
vtkm::cont::ArrayCopy(freqsIn, freqs);
vtkm::Id numMarginalVariables = 0; //count num of marginal variables
const auto marginalPortal = marginalVariables.ReadPortal();
const auto numBinsPortal = numberOfBins.ReadPortal();
for (vtkm::Id i = 0; i < numOfVariable; i++)
{
if (marginalPortal.Get(i) == true)
{
// Worklet to calculate 1D index for marginal variables
numMarginalVariables++;
const vtkm::Id nFieldBins = numBinsPortal.Get(i);
vtkm::worklet::histogram::To1DIndex binWorklet(nFieldBins);
vtkm::worklet::DispatcherMapField<vtkm::worklet::histogram::To1DIndex> to1DIndexDispatcher(
binWorklet);
size_t vecIndex = static_cast<size_t>(i);
to1DIndexDispatcher.Invoke(binId[vecIndex], bin1DIndex, bin1DIndex);
}
}
// Sort the freq array for counting by key (1DIndex)
vtkm::cont::Algorithm::SortByKey(bin1DIndex, freqs);
// Add frequency within same 1d index bin
vtkm::cont::Algorithm::ReduceByKey(bin1DIndex, freqs, bin1DIndex, marginalFreqs, vtkm::Add());
//convert back to multi variate binId
marginalBinId.resize(static_cast<size_t>(numMarginalVariables));
vtkm::Id marginalVarIdx = numMarginalVariables - 1;
for (vtkm::Id i = numOfVariable - 1; i >= 0; i--)
{
if (marginalPortal.Get(i) == true)
{
const vtkm::Id nFieldBins = numBinsPortal.Get(i);
vtkm::worklet::histogram::ConvertHistBinToND binWorklet(nFieldBins);
vtkm::worklet::DispatcherMapField<vtkm::worklet::histogram::ConvertHistBinToND>
convertHistBinToNDDispatcher(binWorklet);
size_t vecIndex = static_cast<size_t>(marginalVarIdx);
convertHistBinToNDDispatcher.Invoke(bin1DIndex, bin1DIndex, marginalBinId[vecIndex]);
marginalVarIdx--;
}
}
} //Run()
};
}
} // namespace vtkm::worklet
#endif // vtk_m_worklet_NDimsHistMarginalization_h