diff --git a/.gitattributes b/.gitattributes index a7a0795f1..79ac99f3f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -13,3 +13,5 @@ data/* filter=lfs diff=lfs merge=lfs -text *.md whitespace=tab-in-indent conflict-marker-size=79 -whitespace *.rst whitespace=tab-in-indent conflict-marker-size=79 *.txt whitespace=tab-in-indent + +diy/** -format.clang-format -whitespace diff --git a/CMake/VTKmCheckCopyright.cmake b/CMake/VTKmCheckCopyright.cmake index f9e4067a5..cddeb3f73 100644 --- a/CMake/VTKmCheckCopyright.cmake +++ b/CMake/VTKmCheckCopyright.cmake @@ -39,6 +39,9 @@ set(FILES_TO_CHECK set(EXCEPTIONS LICENSE.txt README.txt + diy/include/diy + diy/LEGAL.txt + diy/LICENSE.txt ) if (NOT VTKm_SOURCE_DIR) diff --git a/CMake/VTKmConfig.cmake.in b/CMake/VTKmConfig.cmake.in index 75183cbc1..05aa8f7e7 100755 --- a/CMake/VTKmConfig.cmake.in +++ b/CMake/VTKmConfig.cmake.in @@ -62,6 +62,7 @@ set(VTKm_ENABLE_CUDA "@VTKm_ENABLE_CUDA@") set(VTKm_ENABLE_TBB "@VTKm_ENABLE_TBB@") set(VTKm_ENABLE_RENDERING "@VTKm_ENABLE_RENDERING@") set(VTKm_RENDERING_BACKEND "@VTKm_RENDERING_BACKEND@") +set(VTKm_ENABLE_MPI "@VTKm_ENABLE_MPI@") # Load the library exports, but only if not compiling VTK-m itself set_and_check(VTKm_CONFIG_DIR "@PACKAGE_VTKm_INSTALL_CONFIG_DIR@") diff --git a/CMake/VTKmMacros.cmake b/CMake/VTKmMacros.cmake deleted file mode 100644 index 68127826d..000000000 --- a/CMake/VTKmMacros.cmake +++ /dev/null @@ -1,833 +0,0 @@ -##============================================================================ -## Copyright (c) Kitware, Inc. -## All rights reserved. -## See LICENSE.txt for details. -## This software is distributed WITHOUT ANY WARRANTY; without even -## the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR -## PURPOSE. See the above copyright notice for more information. -## -## Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS). -## Copyright 2014 UT-Battelle, LLC. -## Copyright 2014 Los Alamos National Security. 
-## -## Under the terms of Contract DE-NA0003525 with NTESS, -## the U.S. Government retains certain rights in this software. -## -## Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National -## Laboratory (LANL), the U.S. Government retains certain rights in -## this software. -##============================================================================ - -include(CMakeParseArguments) - -# Utility to build a kit name from the current directory. -function(vtkm_get_kit_name kitvar) - # Will this always work? It should if ${CMAKE_CURRENT_SOURCE_DIR} is - # built from ${VTKm_SOURCE_DIR}. - string(REPLACE "${VTKm_SOURCE_DIR}/" "" dir_prefix ${CMAKE_CURRENT_SOURCE_DIR}) - string(REPLACE "/" "_" kit "${dir_prefix}") - set(${kitvar} "${kit}" PARENT_SCOPE) - # Optional second argument to get dir_prefix. - if (${ARGC} GREATER 1) - set(${ARGV1} "${dir_prefix}" PARENT_SCOPE) - endif (${ARGC} GREATER 1) -endfunction(vtkm_get_kit_name) - - -#Utility to setup nvcc flags so that we properly work around issues inside FindCUDA. -#if we are generating cu files need to setup four things. -#1. Explicitly set the cuda device adapter as a define this is currently -# done as a work around since the cuda executable ignores compile -# definitions -#2. Disable unused function warnings -# the FindCUDA module and helper methods don't read target level -# properties so we have to modify CUDA_NVCC_FLAGS instead of using -# target and source level COMPILE_FLAGS and COMPILE_DEFINITIONS -#3. 
Set the compile option /bigobj when using VisualStudio generators -# While we have specified this as target compile flag, those aren't -# currently loooked at by FindCUDA, so we have to manually add it ourselves -function(vtkm_setup_nvcc_flags old_nvcc_flags old_cxx_flags ) - set(${old_nvcc_flags} ${CUDA_NVCC_FLAGS} PARENT_SCOPE) - set(${old_nvcc_flags} ${CMAKE_CXX_FLAGS} PARENT_SCOPE) - set(new_nvcc_flags ${CUDA_NVCC_FLAGS}) - set(new_cxx_flags ${CMAKE_CXX_FLAGS}) - list(APPEND new_nvcc_flags "-DVTKM_DEVICE_ADAPTER=VTKM_DEVICE_ADAPTER_CUDA") - list(APPEND new_nvcc_flags "-w") - if(MSVC) - list(APPEND new_nvcc_flags "--compiler-options;/bigobj") - - # The MSVC compiler gives a warning about having two incompatiable warning - # flags in the command line. So, ironically, adding -w above to remove - # warnings makes MSVC give a warning. To get around that, remove all - # warning flags from the standard CXX arguments (which are typically passed - # to the CUDA compiler). - string(REGEX REPLACE "[-/]W[1-4]" "" new_cxx_flags "${new_cxx_flags}") - string(REGEX REPLACE "[-/]Wall" "" new_cxx_flags "${new_cxx_flags}") - endif() - set(CUDA_NVCC_FLAGS ${new_nvcc_flags} PARENT_SCOPE) - set(CMAKE_CXX_FLAGS ${new_cxx_flags} PARENT_SCOPE) -endfunction(vtkm_setup_nvcc_flags) - -#Utility to set MSVC only COMPILE_DEFINITIONS and COMPILE_FLAGS needed to -#reduce number of warnings and compile issues with Visual Studio -function(vtkm_setup_msvc_properties target ) - if(NOT MSVC) - return() - endif() - - #disable MSVC CRT and SCL warnings as they recommend using non standard - #c++ extensions - target_compile_definitions(${target} PRIVATE "_SCL_SECURE_NO_WARNINGS" - "_CRT_SECURE_NO_WARNINGS") - - #C4702 Generates numerous false positives with template code about - # unreachable code - #C4505 Generates numerous warnings about unused functions being - # removed when doing header test builds. 
- target_compile_options(${target} PRIVATE -wd4702 -wd4505) - - # In VS2013 the C4127 warning has a bug in the implementation and - # generates false positive warnings for lots of template code - if(MSVC_VERSION LESS 1900) - target_compile_options(${target} PRIVATE -wd4127 ) - endif() - -endfunction(vtkm_setup_msvc_properties) - -# vtkm_target_name() -# -# This macro does some basic checking for library naming, and also adds a suffix -# to the output name with the VTKm version by default. Setting the variable -# VTKm_CUSTOM_LIBRARY_SUFFIX will override the suffix. -function(vtkm_target_name _name) - get_property(_type TARGET ${_name} PROPERTY TYPE) - if(NOT "${_type}" STREQUAL EXECUTABLE) - set_property(TARGET ${_name} PROPERTY VERSION 1) - set_property(TARGET ${_name} PROPERTY SOVERSION 1) - endif() - if("${_name}" MATCHES "^[Vv][Tt][Kk][Mm]") - set(_vtkm "") - else() - set(_vtkm "vtkm") - #message(AUTHOR_WARNING "Target [${_name}] does not start in 'vtkm'.") - endif() - # Support custom library suffix names, for other projects wanting to inject - # their own version numbers etc. - if(DEFINED VTKm_CUSTOM_LIBRARY_SUFFIX) - set(_lib_suffix "${VTKm_CUSTOM_LIBRARY_SUFFIX}") - else() - set(_lib_suffix "-${VTKm_VERSION_MAJOR}.${VTKm_VERSION_MINOR}") - endif() - set_property(TARGET ${_name} PROPERTY OUTPUT_NAME ${_vtk}${_name}${_lib_suffix}) -endfunction() - -function(vtkm_target _name) - vtkm_target_name(${_name}) -endfunction() - -# Builds a source file and an executable that does nothing other than -# compile the given header files. 
-function(vtkm_add_header_build_test name dir_prefix use_cuda) - set(hfiles ${ARGN}) - if (use_cuda) - set(suffix ".cu") - else (use_cuda) - set(suffix ".cxx") - endif (use_cuda) - set(cxxfiles) - foreach (header ${ARGN}) - get_source_file_property(cant_be_tested ${header} VTKm_CANT_BE_HEADER_TESTED) - - if( NOT cant_be_tested ) - string(REPLACE "${CMAKE_CURRENT_BINARY_DIR}" "" header "${header}") - get_filename_component(headername ${header} NAME_WE) - set(src ${CMAKE_CURRENT_BINARY_DIR}/TB_${headername}${suffix}) - configure_file(${VTKm_SOURCE_DIR}/CMake/TestBuild.cxx.in ${src} @ONLY) - list(APPEND cxxfiles ${src}) - endif() - - endforeach (header) - - #only attempt to add a test build executable if we have any headers to - #test. this might not happen when everything depends on thrust. - list(LENGTH cxxfiles cxxfiles_len) - if (use_cuda AND ${cxxfiles_len} GREATER 0) - - vtkm_setup_nvcc_flags( old_nvcc_flags old_cxx_flags ) - - # Cuda compiles do not respect target_include_directories - # and we want system includes so we have to hijack cuda - # to do it - foreach(dir ${VTKm_INCLUDE_DIRS}) - #this internal variable has changed names depending on the CMake ver - list(APPEND CUDA_NVCC_INCLUDE_ARGS_USER -isystem ${dir}) - list(APPEND CUDA_NVCC_INCLUDE_DIRS_USER -isystem ${dir}) - endforeach() - - cuda_include_directories(${VTKm_SOURCE_DIR} - ${VTKm_BINARY_INCLUDE_DIR} - ) - - cuda_add_library(TestBuild_${name} STATIC ${cxxfiles} ${hfiles}) - - - set(CUDA_NVCC_FLAGS ${old_nvcc_flags}) - set(CMAKE_CXX_FLAGS ${old_cxx_flags}) - - elseif (${cxxfiles_len} GREATER 0) - add_library(TestBuild_${name} STATIC ${cxxfiles} ${hfiles}) - target_include_directories(TestBuild_${name} PRIVATE vtkm ${VTKm_INCLUDE_DIRS}) - endif () - target_link_libraries(TestBuild_${name} PRIVATE vtkm_cont ${VTKm_LIBRARIES}) - set_source_files_properties(${hfiles} - PROPERTIES HEADER_FILE_ONLY TRUE - ) - - vtkm_setup_msvc_properties(TestBuild_${name}) - - # Send the libraries created for test builds 
to their own directory so as to - # not polute the directory with useful libraries. - set_target_properties(TestBuild_${name} PROPERTIES - ARCHIVE_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH}/testbuilds - LIBRARY_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH}/testbuilds - RUNTIME_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH}/testbuilds - ) -endfunction(vtkm_add_header_build_test) - -function(vtkm_install_headers dir_prefix) - set(hfiles ${ARGN}) - install(FILES ${hfiles} - DESTINATION ${VTKm_INSTALL_INCLUDE_DIR}/${dir_prefix} - ) -endfunction(vtkm_install_headers) - -function(vtkm_install_template_sources) - vtkm_get_kit_name(name dir_prefix) - set(hfiles ${ARGN}) - vtkm_install_headers("${dir_prefix}" ${hfiles}) - # CMake does not add installed files as project files, and template sources - # are not declared as source files anywhere, add a fake target here to let - # an IDE know that these sources exist. - add_custom_target(${name}_template_srcs SOURCES ${hfiles}) -endfunction(vtkm_install_template_sources) - -# Declare a list of headers that require thrust to be enabled -# for them to header tested. In cases of thrust version 1.5 or less -# we have to make sure openMP is enabled, otherwise we are okay -function(vtkm_requires_thrust_to_test) - #determine the state of thrust and testing - set(cant_be_tested FALSE) - if(NOT VTKm_ENABLE_THRUST) - #mark as not valid - set(cant_be_tested TRUE) - elseif(NOT VTKm_ENABLE_OPENMP) - #mark also as not valid - set(cant_be_tested TRUE) - endif() - - foreach(header ${ARGN}) - #set a property on the file that marks if we can header test it - set_source_files_properties( ${header} - PROPERTIES VTKm_CANT_BE_HEADER_TESTED ${cant_be_tested} ) - - endforeach(header) - -endfunction(vtkm_requires_thrust_to_test) - -# Declare a list of header files. Will make sure the header files get -# compiled and show up in an IDE. 
-function(vtkm_declare_headers) - set(options CUDA) - set(oneValueArgs TESTABLE) - set(multiValueArgs) - cmake_parse_arguments(VTKm_DH "${options}" - "${oneValueArgs}" "${multiValueArgs}" - ${ARGN} - ) - - #The testable keyword allows the caller to turn off the header testing, - #mainly used so that backends can be installed even when they can't be - #built on the machine. - #Since this is an optional property not setting it means you do want testing - if(NOT DEFINED VTKm_DH_TESTABLE) - set(VTKm_DH_TESTABLE ON) - endif() - - set(hfiles ${VTKm_DH_UNPARSED_ARGUMENTS}) - vtkm_get_kit_name(name dir_prefix) - - #only do header testing if enable testing is turned on - if (VTKm_ENABLE_TESTING AND VTKm_DH_TESTABLE) - vtkm_add_header_build_test( - "${name}" "${dir_prefix}" "${VTKm_DH_CUDA}" ${hfiles}) - endif() - #always install headers - vtkm_install_headers("${dir_prefix}" ${hfiles}) -endfunction(vtkm_declare_headers) - -# Declare a list of worklet files. -function(vtkm_declare_worklets) - # Currently worklets are just really header files. 
- vtkm_declare_headers(${ARGN}) -endfunction(vtkm_declare_worklets) - -function(vtkm_pyexpander_generated_file generated_file_name) - # If pyexpander is available, add targets to build and check - if(PYEXPANDER_FOUND AND PYTHONINTERP_FOUND) - add_custom_command( - OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/${generated_file_name}.checked - COMMAND ${CMAKE_COMMAND} - -DPYTHON_EXECUTABLE=${PYTHON_EXECUTABLE} - -DPYEXPANDER_COMMAND=${PYEXPANDER_COMMAND} - -DSOURCE_FILE=${CMAKE_CURRENT_SOURCE_DIR}/${generated_file_name} - -DGENERATED_FILE=${CMAKE_CURRENT_BINARY_DIR}/${generated_file_name} - -P ${VTKm_CMAKE_MODULE_PATH}/VTKmCheckPyexpander.cmake - MAIN_DEPENDENCY ${CMAKE_CURRENT_SOURCE_DIR}/${generated_file_name}.in - DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${generated_file_name} - COMMENT "Checking validity of ${generated_file_name}" - ) - add_custom_target(check_${generated_file_name} ALL - DEPENDS ${CMAKE_CURRENT_BINARY_DIR}/${generated_file_name}.checked - ) - endif() -endfunction(vtkm_pyexpander_generated_file) - -# Declare unit tests, which should be in the same directory as a kit -# (package, module, whatever you call it). Usage: -# -# vtkm_unit_tests( -# SOURCES -# LIBRARIES -# TEST_ARGS -# ) -function(vtkm_unit_tests) - set(options CUDA) - set(oneValueArgs) - set(multiValueArgs SOURCES LIBRARIES TEST_ARGS) - cmake_parse_arguments(VTKm_UT - "${options}" "${oneValueArgs}" "${multiValueArgs}" - ${ARGN} - ) - - if (VTKm_ENABLE_TESTING) - vtkm_get_kit_name(kit) - #we use UnitTests_ so that it is an unique key to exclude from coverage - set(test_prog UnitTests_${kit}) - create_test_sourcelist(TestSources ${test_prog}.cxx ${VTKm_UT_SOURCES}) - - #determine the timeout for all the tests based on the backend. CUDA tests - #generally require more time because of kernel generation. 
- set(timeout 180) - if (VTKm_UT_CUDA) - set(timeout 1500) - endif() - - if (VTKm_UT_CUDA) - vtkm_setup_nvcc_flags( old_nvcc_flags old_cxx_flags ) - - # Cuda compiles do not respect target_include_directories - cuda_include_directories(${VTKm_SOURCE_DIR} - ${VTKm_BINARY_INCLUDE_DIR} - ${VTKm_INCLUDE_DIRS} - ) - - cuda_add_executable(${test_prog} ${TestSources}) - - set(CUDA_NVCC_FLAGS ${old_nvcc_flags}) - set(CMAKE_CXX_FLAGS ${old_cxx_flags}) - - else (VTKm_UT_CUDA) - add_executable(${test_prog} ${TestSources}) - endif (VTKm_UT_CUDA) - - set_target_properties(${test_prog} PROPERTIES - ARCHIVE_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - LIBRARY_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - RUNTIME_OUTPUT_DIRECTORY ${VTKm_EXECUTABLE_OUTPUT_PATH} - ) - - #do it as a property value so we don't pollute the include_directories - #for any other targets - target_include_directories(${test_prog} PRIVATE ${VTKm_INCLUDE_DIRS}) - - target_link_libraries(${test_prog} PRIVATE vtkm_cont ${VTKm_LIBRARIES}) - - target_compile_options(${test_prog} PRIVATE ${VTKm_COMPILE_OPTIONS}) - - vtkm_setup_msvc_properties(${test_prog}) - - foreach (test ${VTKm_UT_SOURCES}) - get_filename_component(tname ${test} NAME_WE) - add_test(NAME ${tname} - COMMAND ${test_prog} ${tname} ${VTKm_UT_TEST_ARGS} - ) - set_tests_properties("${tname}" PROPERTIES TIMEOUT ${timeout}) - endforeach (test) - endif (VTKm_ENABLE_TESTING) - -endfunction(vtkm_unit_tests) - -# Save the worklets to test with each device adapter -# Usage: -# -# vtkm_save_worklet_unit_tests( sources ) -# -# notes: will save the sources absolute path as the -# vtkm_source_worklet_unit_tests global property -function(vtkm_save_worklet_unit_tests ) - - #create the test driver when we are called, since - #the test driver expect the test files to be in the same - #directory as the test driver - create_test_sourcelist(test_sources WorkletTestDriver.cxx ${ARGN}) - - #store the absolute path for the test drive and all the test - #files - 
set(driver ${CMAKE_CURRENT_BINARY_DIR}/WorkletTestDriver.cxx) - set(cxx_sources) - set(cu_sources) - - #we need to store the absolute source for the file so that - #we can properly compile it into the test driver. At - #the same time we want to configure each file into the build - #directory as a .cu file so that we can compile it with cuda - #if needed - foreach(fname ${ARGN}) - set(absPath) - - get_filename_component(absPath ${fname} ABSOLUTE) - get_filename_component(file_name_only ${fname} NAME_WE) - - set(cuda_file_name "${CMAKE_CURRENT_BINARY_DIR}/${file_name_only}.cu") - configure_file("${absPath}" - "${cuda_file_name}" - COPYONLY) - list(APPEND cxx_sources ${absPath}) - list(APPEND cu_sources ${cuda_file_name}) - endforeach() - - #we create a property that holds all the worklets to test, - #but don't actually attempt to create a unit test with the yet. - #That is done by each device adapter - set_property( GLOBAL APPEND - PROPERTY vtkm_worklet_unit_tests_sources ${cxx_sources}) - set_property( GLOBAL APPEND - PROPERTY vtkm_worklet_unit_tests_cu_sources ${cu_sources}) - set_property( GLOBAL APPEND - PROPERTY vtkm_worklet_unit_tests_drivers ${driver}) - -endfunction(vtkm_save_worklet_unit_tests) - -# Call each worklet test for the given device adapter -# Usage: -# -# vtkm_worklet_unit_tests( device_adapter ) -# -# notes: will look for the vtkm_source_worklet_unit_tests global -# property to find what are the worklet unit tests that need to be -# compiled for the give device adapter -function(vtkm_worklet_unit_tests device_adapter) - - set(unit_test_srcs) - get_property(unit_test_srcs GLOBAL - PROPERTY vtkm_worklet_unit_tests_sources ) - - set(unit_test_drivers) - get_property(unit_test_drivers GLOBAL - PROPERTY vtkm_worklet_unit_tests_drivers ) - - #detect if we are generating a .cu files - set(is_cuda FALSE) - if("${device_adapter}" STREQUAL "VTKM_DEVICE_ADAPTER_CUDA") - set(is_cuda TRUE) - endif() - - #determine the timeout for all the tests based on the 
backend. The first CUDA - #worklet test requires way more time because of the overhead to allow the - #driver to convert the kernel code from virtual arch to actual arch. - # - set(timeout 180) - if(is_cuda) - set(timeout 1500) - endif() - - if(VTKm_ENABLE_TESTING) - string(REPLACE "VTKM_DEVICE_ADAPTER_" "" device_type ${device_adapter}) - - vtkm_get_kit_name(kit) - - #inject the device adapter into the test program name so each one is unique - set(test_prog WorkletTests_${device_type}) - - - if(is_cuda) - get_property(unit_test_srcs GLOBAL PROPERTY vtkm_worklet_unit_tests_cu_sources ) - vtkm_setup_nvcc_flags( old_nvcc_flags old_cxx_flags ) - - # Cuda compiles do not respect target_include_directories - cuda_include_directories(${VTKm_SOURCE_DIR} - ${VTKm_BINARY_INCLUDE_DIR} - ${VTKm_INCLUDE_DIRS} - ) - - cuda_add_executable(${test_prog} ${unit_test_drivers} ${unit_test_srcs}) - - set(CUDA_NVCC_FLAGS ${old_nvcc_flags}) - set(CMAKE_CXX_FLAGS ${old_cxx_flags}) - else() - add_executable(${test_prog} ${unit_test_drivers} ${unit_test_srcs}) - endif() - - set_target_properties(${test_prog} PROPERTIES - ARCHIVE_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - LIBRARY_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - RUNTIME_OUTPUT_DIRECTORY ${VTKm_EXECUTABLE_OUTPUT_PATH} - ) - target_include_directories(${test_prog} PRIVATE ${VTKm_INCLUDE_DIRS}) - target_link_libraries(${test_prog} PRIVATE vtkm_cont ${VTKm_LIBRARIES}) - - #add the specific compile options for this executable - target_compile_options(${test_prog} PRIVATE ${VTKm_COMPILE_OPTIONS}) - - #add a test for each worklet test file. We will inject the device - #adapter type into the test name so that it is easier to see what - #exact device a test is failing on. 
- foreach (test ${unit_test_srcs}) - get_filename_component(tname ${test} NAME_WE) - add_test(NAME "${tname}${device_type}" - COMMAND ${test_prog} ${tname} - ) - - set_tests_properties("${tname}${device_type}" PROPERTIES TIMEOUT ${timeout}) - endforeach (test) - - vtkm_setup_msvc_properties(${test_prog}) - - #set the device adapter on the executable - target_compile_definitions(${test_prog} PRIVATE "VTKM_DEVICE_ADAPTER=${device_adapter}") - endif() -endfunction(vtkm_worklet_unit_tests) - -# Save the benchmarks to run with each device adapter -# This is based on vtkm_save_worklet_unit_tests -# Usage: -# -# vtkm_save_benchmarks( [HEADERS ] ) -# -# -# Each benchmark source file needs to implement main(int agrc, char *argv[]) -# -# notes: will save the sources absolute path as the -# vtkm_benchmarks_sources global property -function(vtkm_save_benchmarks) - - #store the absolute path for all the test files - set(cxx_sources) - set(cu_sources) - - cmake_parse_arguments(save_benchmarks "" "" "HEADERS" ${ARGN}) - - #we need to store the absolute source for the file so that - #we can properly compile it into the benchmark driver. At - #the same time we want to configure each file into the build - #directory as a .cu file so that we can compile it with cuda - #if needed - foreach(fname ${save_benchmarks_UNPARSED_ARGUMENTS}) - set(absPath) - - get_filename_component(absPath ${fname} ABSOLUTE) - get_filename_component(file_name_only ${fname} NAME_WE) - - set(cuda_file_name "${CMAKE_CURRENT_BINARY_DIR}/${file_name_only}.cu") - configure_file("${absPath}" - "${cuda_file_name}" - COPYONLY) - list(APPEND cxx_sources ${absPath}) - list(APPEND cu_sources ${cuda_file_name}) - endforeach() - - #we create a property that holds all the worklets to test, - #but don't actually attempt to create a unit test with the yet. 
- #That is done by each device adapter - set_property( GLOBAL APPEND - PROPERTY vtkm_benchmarks_sources ${cxx_sources}) - set_property( GLOBAL APPEND - PROPERTY vtkm_benchmarks_cu_sources ${cu_sources}) - set_property( GLOBAL APPEND - PROPERTY vtkm_benchmarks_headers ${save_benchmarks_HEADERS}) - -endfunction(vtkm_save_benchmarks) - -# Call each benchmark for the given device adapter -# Usage: -# -# vtkm_benchmark( device_adapter ) -# -# notes: will look for the vtkm_benchmarks_sources global -# property to find what are the benchmarks that need to be -# compiled for the give device adapter -function(vtkm_benchmarks device_adapter) - - set(benchmark_srcs) - get_property(benchmark_srcs GLOBAL - PROPERTY vtkm_benchmarks_sources ) - - set(benchmark_headers) - get_property(benchmark_headers GLOBAL - PROPERTY vtkm_benchmarks_headers ) - - #detect if we are generating a .cu files - set(is_cuda FALSE) - set(old_nvcc_flags ${CUDA_NVCC_FLAGS}) - set(old_cxx_flags ${CMAKE_CXX_FLAGS}) - if("${device_adapter}" STREQUAL "VTKM_DEVICE_ADAPTER_CUDA") - set(is_cuda TRUE) - endif() - - if(VTKm_ENABLE_BENCHMARKS) - string(REPLACE "VTKM_DEVICE_ADAPTER_" "" device_type ${device_adapter}) - - if(is_cuda) - vtkm_setup_nvcc_flags( old_nvcc_flags old_cxx_flags ) - get_property(benchmark_srcs GLOBAL PROPERTY vtkm_benchmarks_cu_sources ) - endif() - - foreach( file ${benchmark_srcs}) - #inject the device adapter into the benchmark program name so each one is unique - get_filename_component(benchmark_prog ${file} NAME_WE) - set(benchmark_prog "${benchmark_prog}_${device_type}") - - if(is_cuda) - # Cuda compiles do not respect target_include_directories - - cuda_include_directories(${VTKm_SOURCE_DIR} - ${VTKm_BINARY_INCLUDE_DIR} - ${VTKm_BACKEND_INCLUDE_DIRS} - ) - - cuda_add_executable(${benchmark_prog} ${file} ${benchmark_headers}) - else() - add_executable(${benchmark_prog} ${file} ${benchmark_headers}) - endif() - - set_target_properties(${benchmark_prog} PROPERTIES - 
ARCHIVE_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - LIBRARY_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - RUNTIME_OUTPUT_DIRECTORY ${VTKm_EXECUTABLE_OUTPUT_PATH} - ) - - set_source_files_properties(${benchmark_headers} - PROPERTIES HEADER_FILE_ONLY TRUE) - - target_include_directories(${benchmark_prog} PRIVATE ${VTKm_BACKEND_INCLUDE_DIRS}) - target_link_libraries(${benchmark_prog} PRIVATE vtkm_cont ${VTKm_BACKEND_LIBRARIES}) - - vtkm_setup_msvc_properties(${benchmark_prog}) - - #add the specific compile options for this executable - target_compile_options(${benchmark_prog} PRIVATE ${VTKm_COMPILE_OPTIONS}) - - #set the device adapter on the executable - target_compile_definitions(${benchmark_prog} PRIVATE "VTKM_DEVICE_ADAPTER=${device_adapter}") - - endforeach() - - if(is_cuda) - set(CUDA_NVCC_FLAGS ${old_nvcc_flags}) - set(CMAKE_CXX_FLAGS ${old_cxx_flags}) - endif() - endif() - -endfunction(vtkm_benchmarks) - -# Given a list of *.cxx source files that during configure time are deterimined -# to have CUDA code, wrap the sources in *.cu files so that they get compiled -# with nvcc. -function(vtkm_wrap_sources_for_cuda cuda_source_list_var) - set(original_sources ${ARGN}) - - set(cuda_sources) - foreach(source_file ${original_sources}) - get_filename_component(source_name ${source_file} NAME_WE) - get_filename_component(source_file_path ${source_file} ABSOLUTE) - set(wrapped_file ${CMAKE_CURRENT_BINARY_DIR}/${source_name}.cu) - configure_file( - ${VTKm_SOURCE_DIR}/CMake/WrapCUDASource.cu.in - ${wrapped_file} - @ONLY) - list(APPEND cuda_sources ${wrapped_file}) - endforeach(source_file) - - # Set original sources as header files (which they basically are) so that - # we can add them to the file list and they will show up in IDE but they will - # not be compiled separately. 
- set_source_files_properties(${original_sources} - PROPERTIES HEADER_FILE_ONLY TRUE - ) - set(${cuda_source_list_var} ${cuda_sources} ${original_sources} PARENT_SCOPE) -endfunction(vtkm_wrap_sources_for_cuda) - -# Add a VTK-m library. The name of the library will match the "kit" name -# (e.g. vtkm_rendering) unless the NAME argument is given. -# -# vtkm_library( -# [NAME ] -# SOURCES -# [HEADERS ] -# [CUDA] -# [WRAP_FOR_CUDA ] -# [LIBRARIES ] -# ) -function(vtkm_library) - set(options CUDA) - set(oneValueArgs NAME) - set(multiValueArgs SOURCES HEADERS WRAP_FOR_CUDA) - cmake_parse_arguments(VTKm_LIB - "${options}" "${oneValueArgs}" "${multiValueArgs}" - ${ARGN} - ) - - vtkm_get_kit_name(kit dir_prefix) - if(VTKm_LIB_NAME) - set(lib_name ${VTKm_LIB_NAME}) - else() - set(lib_name ${kit}) - endif() - - list(APPEND VTKm_LIB_SOURCES ${VTKm_LIB_HEADERS}) - set_source_files_properties(${VTKm_LIB_HEADERS} - PROPERTIES HEADER_FILE_ONLY TRUE - ) - - if(VTKm_LIB_WRAP_FOR_CUDA) - if(VTKm_ENABLE_CUDA) - # If we have some sources marked as WRAP_FOR_CUDA and we support CUDA, - # then we need to turn on CDUA, wrap those sources, and add the wrapped - # code to the sources list. - set(VTKm_LIB_CUDA TRUE) - vtkm_wrap_sources_for_cuda(cuda_sources ${VTKm_LIB_WRAP_FOR_CUDA}) - list(APPEND VTKm_LIB_SOURCES ${cuda_sources}) - else() - # If we have some sources marked as WRAP_FOR_CUDA but we do not support - # CUDA, then just compile these sources normally by adding them to the - # sources list. 
- list(APPEND VTKm_LIB_SOURCES ${VTKm_LIB_WRAP_FOR_CUDA}) - endif() - endif() - - if(VTKm_LIB_CUDA) - vtkm_setup_nvcc_flags(old_nvcc_flags old_cxx_flags) - - - # Cuda compiles do not respect target_include_directories - cuda_include_directories(${VTKm_SOURCE_DIR} - ${VTKm_BINARY_INCLUDE_DIR} - ${VTKm_BACKEND_INCLUDE_DIRS} - ) - - if(BUILD_SHARED_LIBS AND NOT WIN32) - set(compile_options -Xcompiler=${CMAKE_CXX_COMPILE_OPTIONS_VISIBILITY}hidden) - endif() - - cuda_add_library(${lib_name} ${VTKm_LIB_SOURCES} - OPTIONS "${compile_options}") - - set(CUDA_NVCC_FLAGS ${old_nvcc_flags}) - set(CMAKE_CXX_FLAGS ${old_cxx_flags}) - else() - add_library(${lib_name} ${VTKm_LIB_SOURCES}) - endif() - - vtkm_target(${lib_name}) - - target_link_libraries(${lib_name} PUBLIC vtkm) - target_link_libraries(${lib_name} PRIVATE - ${VTKm_BACKEND_LIBRARIES} - ${VTKm_LIB_LIBRARIES} - ) - - set(cxx_args ${VTKm_COMPILE_OPTIONS}) - separate_arguments(cxx_args) - target_compile_options(${lib_name} PRIVATE ${cxx_args}) - - # Make sure libraries go to lib directory and dll go to bin directory. - # Mostly important on Windows. - set_target_properties(${lib_name} PROPERTIES - ARCHIVE_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - LIBRARY_OUTPUT_DIRECTORY ${VTKm_LIBRARY_OUTPUT_PATH} - RUNTIME_OUTPUT_DIRECTORY ${VTKm_EXECUTABLE_OUTPUT_PATH} - ) - - vtkm_setup_msvc_properties(${lib_name}) - - if(VTKm_EXTRA_COMPILER_WARNINGS) - set(cxx_args ${CMAKE_CXX_FLAGS_WARN_EXTRA}) - separate_arguments(cxx_args) - target_compile_options(${lib_name} - PRIVATE ${cxx_args} - ) - endif(VTKm_EXTRA_COMPILER_WARNINGS) - - #Now generate a header that holds the macros needed to easily export - #template classes. 
This - string(TOUPPER ${lib_name} BASE_NAME_UPPER) - set(EXPORT_MACRO_NAME "${BASE_NAME_UPPER}") - - set(EXPORT_IS_BUILT_STATIC 0) - get_target_property(is_static ${lib_name} TYPE) - if(${is_static} STREQUAL "STATIC_LIBRARY") - #If we are building statically set the define symbol - set(EXPORT_IS_BUILT_STATIC 1) - endif() - unset(is_static) - - get_target_property(EXPORT_IMPORT_CONDITION ${lib_name} DEFINE_SYMBOL) - if(NOT EXPORT_IMPORT_CONDITION) - #set EXPORT_IMPORT_CONDITION to what the DEFINE_SYMBOL would be when - #building shared - set(EXPORT_IMPORT_CONDITION ${lib_name}_EXPORTS) - endif() - - configure_file( - ${VTKm_SOURCE_DIR}/CMake/VTKmExportHeaderTemplate.h.in - ${VTKm_BINARY_INCLUDE_DIR}/${dir_prefix}/${lib_name}_export.h - @ONLY) - - unset(EXPORT_MACRO_NAME) - unset(EXPORT_IS_BUILT_STATIC) - unset(EXPORT_IMPORT_CONDITION) - - install(TARGETS ${lib_name} - EXPORT ${VTKm_EXPORT_NAME} - ARCHIVE DESTINATION ${VTKm_INSTALL_LIB_DIR} - LIBRARY DESTINATION ${VTKm_INSTALL_LIB_DIR} - RUNTIME DESTINATION ${VTKm_INSTALL_BIN_DIR} - ) - vtkm_install_headers("${dir_prefix}" - ${VTKm_BINARY_INCLUDE_DIR}/${dir_prefix}/${lib_name}_export.h - ${VTKm_LIB_HEADERS} - ) -endfunction(vtkm_library) - -# The Thrust project is not as careful as the VTKm project in avoiding warnings -# on shadow variables and unused arguments. With a real GCC compiler, you -# can disable these warnings inline, but with something like nvcc, those -# pragmas cause errors. Thus, this macro will disable the compiler warnings. 
-macro(vtkm_disable_troublesome_thrust_warnings) - vtkm_disable_troublesome_thrust_warnings_var(CMAKE_CXX_FLAGS_DEBUG) - vtkm_disable_troublesome_thrust_warnings_var(CMAKE_CXX_FLAGS_MINSIZEREL) - vtkm_disable_troublesome_thrust_warnings_var(CMAKE_CXX_FLAGS_RELEASE) - vtkm_disable_troublesome_thrust_warnings_var(CMAKE_CXX_FLAGS_RELWITHDEBINFO) -endmacro(vtkm_disable_troublesome_thrust_warnings) - -macro(vtkm_disable_troublesome_thrust_warnings_var flags_var) - set(old_flags "${${flags_var}}") - string(REPLACE "-Wshadow" "" new_flags "${old_flags}") - string(REPLACE "-Wunused-parameter" "" new_flags "${new_flags}") - string(REPLACE "-Wunused" "" new_flags "${new_flags}") - string(REPLACE "-Wextra" "" new_flags "${new_flags}") - string(REPLACE "-Wall" "" new_flags "${new_flags}") - set(${flags_var} "${new_flags}") -endmacro(vtkm_disable_troublesome_thrust_warnings_var) - -include(VTKmConfigureComponents) diff --git a/CMake/VTKmWrappers.cmake b/CMake/VTKmWrappers.cmake index 219f54398..8349338e8 100644 --- a/CMake/VTKmWrappers.cmake +++ b/CMake/VTKmWrappers.cmake @@ -141,7 +141,7 @@ function(vtkm_library) endif() set(lib_name ${VTKm_LIB_NAME}) - if(VTKm_ENABLE_CUDA) + if(TARGET vtkm::cuda) set_source_files_properties(${VTKm_LIB_WRAP_FOR_CUDA} PROPERTIES LANGUAGE "CUDA") endif() @@ -176,23 +176,37 @@ endfunction(vtkm_library) # Declare unit tests, which should be in the same directory as a kit # (package, module, whatever you call it). 
Usage: # -# [CUDA]: mark all source files as being compiled with the cuda compiler +# vtkm_unit_tests( +# NAME +# SOURCES +# BACKEND +# LIBRARIES +# TEST_ARGS +# +# ) +# # [BACKEND]: mark all source files as being compiled with the proper defines # to make this backend the default backend # If the backend is specified as CUDA it will also imply all # sources should be treated as CUDA sources # The backend name will also be added to the executable name # so you can test multiple backends easily -# vtkm_unit_tests( -# NAME -# CUDA -# SOURCES -# BACKEND -# LIBRARIES -# TEST_ARGS -# ) +# +# [LIBRARIES] : extra libraries that this set of tests need to link too +# +# [TEST_ARGS] : arguments that should be passed on the command line to the +# test executable +# +# Supported are documented below. These can be specified for +# all tests or for individual tests. When specifying these for individual tests, +# simply add them after the test name in the separated by a comma. +# e.g. `UnitTestMultiBlock,MPI`. +# +# Supported are +# * MPI : the test(s) will be executed using `mpirun`. +# function(vtkm_unit_tests) - set(options CUDA NO_TESTS) + set(options MPI) set(oneValueArgs BACKEND NAME) set(multiValueArgs SOURCES LIBRARIES TEST_ARGS) cmake_parse_arguments(VTKm_UT @@ -204,12 +218,11 @@ function(vtkm_unit_tests) return() endif() + vtkm_parse_test_options(VTKm_UT_SOURCES "${options}" ${VTKm_UT_SOURCES}) + set(backend ) if(VTKm_UT_BACKEND) set(backend "_${VTKm_UT_BACKEND}") - if(backend STREQUAL "CUDA") - set(VTKm_UT_CUDA "TRUE") - endif() endif() vtkm_get_kit_name(kit) @@ -219,6 +232,9 @@ function(vtkm_unit_tests) set(test_prog "${VTKm_UT_NAME}${backend}") endif() + if(VTKm_UT_BACKEND STREQUAL "CUDA") + set_source_files_properties(${VTKm_UT_SOURCES} PROPERTIES LANGUAGE "CUDA") + endif() create_test_sourcelist(TestSources ${test_prog}.cxx ${VTKm_UT_SOURCES}) @@ -238,7 +254,7 @@ function(vtkm_unit_tests) #determine the timeout for all the tests based on the backend. 
CUDA tests #generally require more time because of kernel generation. set(timeout 180) - if(VTKm_UT_CUDA) + if(VTKm_UT_BACKEND STREQUAL "CUDA") set(timeout 1500) endif() foreach (test ${VTKm_UT_SOURCES}) @@ -251,20 +267,36 @@ function(vtkm_unit_tests) endfunction(vtkm_unit_tests) -#----------------------------------------------------------------------------- -# Declare benchmarks, which use all the same infrastructure as tests but -# don't actually do the add_test at the end +# ----------------------------------------------------------------------------- +# vtkm_parse_test_options(varname options) +# INTERNAL: Parse options specified for individual tests. # -# [BACKEND]: mark all source files as being compiled with the proper defines -# to make this backend the default backend -# If the backend is specified as CUDA it will also imply all -# sources should be treated as CUDA sources -# The backend name will also be added to the executable name -# so you can test multiple backends easily -# vtkm_benchmarks( -# SOURCES -# BACKEND -# LIBRARIES -function(vtkm_benchmarks) - vtkm_unit_tests(NAME Benchmarks NO_TESTS ${ARGN}) -endfunction(vtkm_benchmarks) +# Parses the arguments to separate out options specified after the test name +# separated by a comma e.g. +# +# TestName,Option1,Option2 +# +# For every option in options, this will set _TestName_Option1, +# _TestName_Option2, etc in the parent scope. 
+# +function(vtkm_parse_test_options varname options) + set(names) + foreach(arg IN LISTS ARGN) + set(test_name ${arg}) + set(test_options) + if(test_name AND "x${test_name}" MATCHES "^x([^,]*),(.*)$") + set(test_name "${CMAKE_MATCH_1}") + string(REPLACE "," ";" test_options "${CMAKE_MATCH_2}") + endif() + foreach(opt IN LISTS test_options) + list(FIND options "${opt}" index) + if(index EQUAL -1) + message(WARNING "Unknown option '${opt}' specified for test '${test_name}'") + else() + set(_${test_name}_${opt} TRUE PARENT_SCOPE) + endif() + endforeach() + list(APPEND names ${test_name}) + endforeach() + set(${varname} ${names} PARENT_SCOPE) +endfunction() diff --git a/CMakeLists.txt b/CMakeLists.txt index eb0d2b736..3dc842bbd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,7 @@ option(VTKm_ENABLE_TBB "Enable TBB support" OFF) option(VTKm_ENABLE_RENDERING "Enable rendering library" ON) option(VTKm_ENABLE_TESTING "Enable VTKm Testing" ON) option(VTKm_ENABLE_BENCHMARKS "Enable VTKm Benchmarking" OFF) +option(VTKm_ENABLE_MPI "Enable MPI support" OFF) option(VTKm_ENABLE_DOCUMENTATION "Build Doxygen documentation" OFF) option(VTKm_ENABLE_EXAMPLES "Build examples" OFF) @@ -181,6 +182,11 @@ find_package(Pyexpander) #----------------------------------------------------------------------------- # Add subdirectories +if(VTKm_ENABLE_MPI) + # This `if` is temporary and will be removed once `diy` supports building + # without MPI. + add_subdirectory(diy) +endif() add_subdirectory(vtkm) #----------------------------------------------------------------------------- diff --git a/diy/CMakeLists.txt b/diy/CMakeLists.txt new file mode 100644 index 000000000..0b3fb112f --- /dev/null +++ b/diy/CMakeLists.txt @@ -0,0 +1,65 @@ +##============================================================================= +## +## Copyright (c) Kitware, Inc. +## All rights reserved. +## See LICENSE.txt for details. 
+## +## This software is distributed WITHOUT ANY WARRANTY; without even +## the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +## PURPOSE. See the above copyright notice for more information. +## +## Copyright 2017 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +## Copyright 2017 UT-Battelle, LLC. +## Copyright 2017 Los Alamos National Security. +## +## Under the terms of Contract DE-NA0003525 with NTESS, +## the U.S. Government retains certain rights in this software. +## Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +## Laboratory (LANL), the U.S. Government retains certain rights in +## this software. +## +##============================================================================= + +#============================================================================== +# See License.txt +#============================================================================== +add_library(diy INTERFACE) + +# diy needs C++11 +target_compile_features(diy INTERFACE cxx_auto_type) + +target_include_directories(diy INTERFACE + $ + $) + +# presently, this dependency is required. Make it optional in the future. +set(arg) +foreach(apath IN LISTS MPI_C_INCLUDE_PATH MPI_CXX_INCLUDE_PATH) + list(APPEND arg $) +endforeach() +target_include_directories(diy INTERFACE ${arg}) + +target_link_libraries(diy INTERFACE + $ + $) + +if(MPI_C_COMPILE_DEFINITIONS) + target_compile_definitions(diy INTERFACE + $<$:${MPI_C_COMPILE_DEFINITIONS}>) +endif() +if(MPI_CXX_COMPILE_DEFINITIONS) + target_compile_definitions(diy INTERFACE + $<$:${MPI_CXX_COMPILE_DEFINITIONS}>) +endif() + +install(TARGETS diy + EXPORT ${VTKm_EXPORT_NAME}) + +# Install headers +install(DIRECTORY include/diy + DESTINATION ${VTKm_INSTALL_INCLUDE_DIR}) + +# Install other files.
+install(FILES LEGAL.txt LICENSE.txt + DESTINATION ${VTKm_INSTALL_INCLUDE_DIR}/diy + ) diff --git a/diy/LEGAL.txt b/diy/LEGAL.txt new file mode 100644 index 000000000..66955ef03 --- /dev/null +++ b/diy/LEGAL.txt @@ -0,0 +1,19 @@ +Copyright Notice + +DIY2, Copyright (c) 2015, The Regents of the University of California, through +Lawrence Berkeley National Laboratory (subject to receipt of any required +approvals from the U.S. Dept. of Energy). All rights reserved. + +If you have questions about your rights to use or distribute this software, +please contact Berkeley Lab's Technology Transfer Department at TTD@lbl.gov. + +NOTICE. This software is owned by the U.S. Department of Energy. As such, the +U.S. Government has been granted for itself and others acting on its behalf a +paid-up, nonexclusive, irrevocable, worldwide license in the Software to +reproduce, prepare derivative works, and perform publicly and display publicly. +Beginning five (5) years after the date permission to assert copyright is +obtained from the U.S. Department of Energy, and subject to any subsequent five +(5) year renewals, the U.S. Government is granted for itself and others acting +on its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the +Software to reproduce, prepare derivative works, distribute copies to the +public, perform publicly and display publicly, and to permit others to do so. diff --git a/diy/LICENSE.txt b/diy/LICENSE.txt new file mode 100644 index 000000000..7607d2ca1 --- /dev/null +++ b/diy/LICENSE.txt @@ -0,0 +1,41 @@ +License Agreement + +"DIY2, Copyright (c) 2015, The Regents of the University of California, through +Lawrence Berkeley National Laboratory (subject to receipt of any required +approvals from the U.S. Dept. of Energy). All rights reserved." 
+ +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +(1) Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +(2) Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation and/or +other materials provided with the distribution. + +(3) Neither the name of the University of California, Lawrence Berkeley National +Laboratory, U.S. Dept. of Energy nor the names of its contributors may be used +to endorse or promote products derived from this software without specific prior +written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +You are under no obligation whatsoever to provide any bug fixes, patches, or +upgrades to the features, functionality or performance of the source code +("Enhancements") to anyone; however, if you choose to make your Enhancements +available either publicly, or directly to Lawrence Berkeley National Laboratory, +without imposing a separate written license agreement for such Enhancements, +then you hereby grant the following license: a non-exclusive, royalty-free +perpetual license to install, use, modify, prepare derivative works, incorporate +into other computer software, distribute, and sublicense such enhancements or +derivative works thereof, in binary and source code form. diff --git a/diy/include/diy/algorithms.hpp b/diy/include/diy/algorithms.hpp new file mode 100644 index 000000000..23215a2c3 --- /dev/null +++ b/diy/include/diy/algorithms.hpp @@ -0,0 +1,191 @@ +#ifndef DIY_ALGORITHMS_HPP +#define DIY_ALGORITHMS_HPP + +#include + +#include "master.hpp" +#include "assigner.hpp" +#include "reduce.hpp" +#include "reduce-operations.hpp" +#include "partners/swap.hpp" + +#include "detail/algorithms/sort.hpp" +#include "detail/algorithms/kdtree.hpp" +#include "detail/algorithms/kdtree-sampling.hpp" + +#include "log.hpp" + +namespace diy +{ + /** + * \ingroup Algorithms + * \brief sample sort `values` of each block, store the boundaries between blocks in `samples` + */ + template + void sort(Master& master, //!< master object + const Assigner& assigner, //!< assigner object + std::vector Block::* values, //!< all values to sort + std::vector Block::* samples, //!< (output) boundaries of blocks + size_t num_samples, //!< desired number of samples + const Cmp& cmp, //!< comparison function + int k = 2, //!< k-ary reduction will be used + bool samples_only = false) //!< false: results will be all_to_all exchanged; true: only sort but don't exchange results + { + bool immediate = master.immediate(); + master.set_immediate(false); + + // NB: although sorter will go out 
of scope, its member functions sample() + // and exchange() will return functors whose copies get saved inside reduce + detail::SampleSort sorter(values, samples, cmp, num_samples); + + // swap-reduce to all-gather samples + RegularDecomposer decomposer(1, interval(0,assigner.nblocks()), assigner.nblocks()); + RegularSwapPartners partners(decomposer, k); + reduce(master, assigner, partners, sorter.sample(), detail::SkipIntermediate(partners.rounds())); + + // all_to_all to exchange the values + if (!samples_only) + all_to_all(master, assigner, sorter.exchange(), k); + + master.set_immediate(immediate); + } + + + /** + * \ingroup Algorithms + * \brief sample sort `values` of each block, store the boundaries between blocks in `samples` + * shorter version of above sort algorithm with the default less-than comparator used for T + * and all_to_all exchange included + */ + template + void sort(Master& master, //!< master object + const Assigner& assigner, //!< assigner object + std::vector Block::* values, //!< all values to sort + std::vector Block::* samples, //!< (output) boundaries of blocks + size_t num_samples, //!< desired number of samples + int k = 2) //!< k-ary reduction will be used + { + sort(master, assigner, values, samples, num_samples, std::less(), k); + } + + /** + * \ingroup Algorithms + * \brief build a kd-tree and sort a set of points into it (use histograms to determine split values) + */ + template + void kdtree(Master& master, //!< master object + const Assigner& assigner, //!< assigner object + int dim, //!< dimensionality + const ContinuousBounds& domain, //!< global data extents + std::vector Block::* points, //!< input points to sort into kd-tree + size_t bins, //!< number of histogram bins for splitting a dimension + bool wrap = false)//!< periodic boundaries in all dimensions + { + if (assigner.nblocks() & (assigner.nblocks() - 1)) + throw std::runtime_error(fmt::format("KD-tree requires a number of blocks that's a power of 2, got {}", 
assigner.nblocks())); + + typedef diy::RegularContinuousLink RCLink; + + for (size_t i = 0; i < master.size(); ++i) + { + RCLink* link = static_cast(master.link(i)); + *link = RCLink(dim, domain, domain); + + if (wrap) // set up the links to self + { + diy::BlockID self = { master.gid(i), master.communicator().rank() }; + for (int j = 0; j < dim; ++j) + { + diy::Direction dir, wrap_dir; + + // left + dir[j] = -1; wrap_dir[j] = -1; + link->add_neighbor(self); + link->add_bounds(domain); + link->add_direction(dir); + link->add_wrap(wrap_dir); + + // right + dir[j] = 1; wrap_dir[j] = 1; + link->add_neighbor(self); + link->add_bounds(domain); + link->add_direction(dir); + link->add_wrap(wrap_dir); + } + } + } + + detail::KDTreePartition kdtree_partition(dim, points, bins); + + detail::KDTreePartners partners(dim, assigner.nblocks(), wrap, domain); + reduce(master, assigner, partners, kdtree_partition); + + // update master.expected to match the links + int expected = 0; + for (size_t i = 0; i < master.size(); ++i) + expected += master.link(i)->size_unique(); + master.set_expected(expected); + } + + /** + * \ingroup Algorithms + * \brief build a kd-tree and sort a set of points into it (use sampling to determine split values) + */ + template + void kdtree_sampling + (Master& master, //!< master object + const Assigner& assigner, //!< assigner object + int dim, //!< dimensionality + const ContinuousBounds& domain, //!< global data extents + std::vector Block::* points, //!< input points to sort into kd-tree + size_t samples, //!< number of samples to take in each block + bool wrap = false)//!< periodic boundaries in all dimensions + { + if (assigner.nblocks() & (assigner.nblocks() - 1)) + throw std::runtime_error(fmt::format("KD-tree requires a number of blocks that's a power of 2, got {}", assigner.nblocks())); + + typedef diy::RegularContinuousLink RCLink; + + for (size_t i = 0; i < master.size(); ++i) + { + RCLink* link = static_cast(master.link(i)); + *link = 
RCLink(dim, domain, domain); + + if (wrap) // set up the links to self + { + diy::BlockID self = { master.gid(i), master.communicator().rank() }; + for (int j = 0; j < dim; ++j) + { + diy::Direction dir, wrap_dir; + + // left + dir[j] = -1; wrap_dir[j] = -1; + link->add_neighbor(self); + link->add_bounds(domain); + link->add_direction(dir); + link->add_wrap(wrap_dir); + + // right + dir[j] = 1; wrap_dir[j] = 1; + link->add_neighbor(self); + link->add_bounds(domain); + link->add_direction(dir); + link->add_wrap(wrap_dir); + } + } + } + + detail::KDTreeSamplingPartition kdtree_partition(dim, points, samples); + + detail::KDTreePartners partners(dim, assigner.nblocks(), wrap, domain); + reduce(master, assigner, partners, kdtree_partition); + + // update master.expected to match the links + int expected = 0; + for (size_t i = 0; i < master.size(); ++i) + expected += master.link(i)->size_unique(); + master.set_expected(expected); + } +} + +#endif diff --git a/diy/include/diy/assigner.hpp b/diy/include/diy/assigner.hpp new file mode 100644 index 000000000..957596ddc --- /dev/null +++ b/diy/include/diy/assigner.hpp @@ -0,0 +1,126 @@ +#ifndef DIY_ASSIGNER_HPP +#define DIY_ASSIGNER_HPP + +#include + +namespace diy +{ + // Derived types should define + // int rank(int gid) const + // that converts a global block id to a rank that it's assigned to. + class Assigner + { + public: + /** + * \ingroup Assignment + * \brief Manages how blocks are assigned to processes + */ + Assigner(int size, //!< total number of processes + int nblocks //!< total (global) number of blocks + ): + size_(size), nblocks_(nblocks) {} + + //! returns the total number of process ranks + int size() const { return size_; } + //! returns the total number of global blocks + int nblocks() const { return nblocks_; } + //! sets the total number of global blocks + void set_nblocks(int nblocks) { nblocks_ = nblocks; } + //! 
gets the local gids for a given process rank + virtual void local_gids(int rank, std::vector& gids) const =0; + //! returns the process rank of the block with global id gid (need not be local) + virtual int rank(int gid) const =0; + + private: + int size_; // total number of ranks + int nblocks_; // total number of blocks + }; + + class ContiguousAssigner: public Assigner + { + public: + /** + * \ingroup Assignment + * \brief Assigns blocks to processes in contiguous gid (block global id) order + */ + ContiguousAssigner(int size, //!< total number of processes + int nblocks //!< total (global) number of blocks + ): + Assigner(size, nblocks) {} + + using Assigner::size; + using Assigner::nblocks; + + int rank(int gid) const override + { + int div = nblocks() / size(); + int mod = nblocks() % size(); + int r = gid / (div + 1); + if (r < mod) + { + return r; + } else + { + return mod + (gid - (div + 1)*mod)/div; + } + } + inline + void local_gids(int rank, std::vector& gids) const override; + }; + + class RoundRobinAssigner: public Assigner + { + public: + /** + * \ingroup Assignment + * \brief Assigns blocks to processes in cyclic or round-robin gid (block global id) order + */ + RoundRobinAssigner(int size, //!< total number of processes + int nblocks //!< total (global) number of blocks + ): + Assigner(size, nblocks) {} + + using Assigner::size; + using Assigner::nblocks; + + int rank(int gid) const override { return gid % size(); } + inline + void local_gids(int rank, std::vector& gids) const override; + }; +} + +void +diy::ContiguousAssigner:: +local_gids(int rank, std::vector& gids) const +{ + int div = nblocks() / size(); + int mod = nblocks() % size(); + + int from, to; + if (rank < mod) + from = rank * (div + 1); + else + from = mod * (div + 1) + (rank - mod) * div; + + if (rank + 1 < mod) + to = (rank + 1) * (div + 1); + else + to = mod * (div + 1) + (rank + 1 - mod) * div; + + for (int gid = from; gid < to; ++gid) + gids.push_back(gid); +} + +void 
+diy::RoundRobinAssigner:: +local_gids(int rank, std::vector& gids) const +{ + int cur = rank; + while (cur < nblocks()) + { + gids.push_back(cur); + cur += size(); + } +} + +#endif diff --git a/diy/include/diy/collection.hpp b/diy/include/diy/collection.hpp new file mode 100644 index 000000000..c24af95f5 --- /dev/null +++ b/diy/include/diy/collection.hpp @@ -0,0 +1,121 @@ +#ifndef DIY_COLLECTION_HPP +#define DIY_COLLECTION_HPP + +#include + +#include "serialization.hpp" +#include "storage.hpp" +#include "thread.hpp" + + +namespace diy +{ + class Collection + { + public: + typedef void* Element; + typedef std::vector Elements; + typedef critical_resource CInt; + + typedef void* (*Create)(); + typedef void (*Destroy)(void*); + typedef detail::Save Save; + typedef detail::Load Load; + + public: + Collection(Create create, + Destroy destroy, + ExternalStorage* storage, + Save save, + Load load): + create_(create), + destroy_(destroy), + storage_(storage), + save_(save), + load_(load), + in_memory_(0) {} + + size_t size() const { return elements_.size(); } + const CInt& in_memory() const { return in_memory_; } + inline void clear(); + + int add(Element e) { elements_.push_back(e); external_.push_back(-1); ++(*in_memory_.access()); return elements_.size() - 1; } + void* release(int i) { void* e = get(i); elements_[i] = 0; return e; } + + void* find(int i) const { return elements_[i]; } // possibly returns 0, if the element is unloaded + void* get(int i) { if (!find(i)) load(i); return find(i); } // loads the element first, and then returns its address + + int available() const { int i = 0; for (; i < (int)size(); ++i) if (find(i) != 0) break; return i; } + + inline void load(int i); + inline void unload(int i); + + Create creator() const { return create_; } + Destroy destroyer() const { return destroy_; } + Load loader() const { return load_; } + Save saver() const { return save_; } + + void* create() const { return create_(); } + void destroy(int i) { if (find(i)) { 
destroy_(find(i)); elements_[i] = 0; } else if (external_[i] != -1) storage_->destroy(external_[i]); } + + bool own() const { return destroy_ != 0; } + + ExternalStorage* storage() const { return storage_; } + + private: + Create create_; + Destroy destroy_; + ExternalStorage* storage_; + Save save_; + Load load_; + + Elements elements_; + std::vector external_; + CInt in_memory_; + }; +} + +void +diy::Collection:: +clear() +{ + if (own()) + for (size_t i = 0; i < size(); ++i) + destroy(i); + elements_.clear(); + external_.clear(); + *in_memory_.access() = 0; +} + +void +diy::Collection:: +unload(int i) +{ + //BinaryBuffer bb; + void* e = find(i); + //save_(e, bb); + //external_[i] = storage_->put(bb); + external_[i] = storage_->put(e, save_); + + destroy_(e); + elements_[i] = 0; + + --(*in_memory_.access()); +} + +void +diy::Collection:: +load(int i) +{ + //BinaryBuffer bb; + //storage_->get(external_[i], bb); + void* e = create_(); + //load_(e, bb); + storage_->get(external_[i], e, load_); + elements_[i] = e; + external_[i] = -1; + + ++(*in_memory_.access()); +} + +#endif diff --git a/diy/include/diy/communicator.hpp b/diy/include/diy/communicator.hpp new file mode 100644 index 000000000..b95708298 --- /dev/null +++ b/diy/include/diy/communicator.hpp @@ -0,0 +1,13 @@ +#ifndef DIY_COMMUNICATOR_HPP +#define DIY_COMMUNICATOR_HPP + +#warning "diy::Communicator (in diy/communicator.hpp) is deprecated, use diy::mpi::communicator directly" + +#include "mpi.hpp" + +namespace diy +{ + typedef mpi::communicator Communicator; +} + +#endif diff --git a/diy/include/diy/constants.h b/diy/include/diy/constants.h new file mode 100644 index 000000000..e3c9cc563 --- /dev/null +++ b/diy/include/diy/constants.h @@ -0,0 +1,22 @@ +#ifndef DIY_CONSTANTS_H +#define DIY_CONSTANTS_H + +// Default DIY_MAX_DIM to 4, unless provided by the user +// (used for static min/max size in various Bounds) +#ifndef DIY_MAX_DIM +#define DIY_MAX_DIM 4 +#endif + +enum +{ + DIY_X0 = 0x01, /* minimum-side 
x (left) neighbor */ + DIY_X1 = 0x02, /* maximum-side x (right) neighbor */ + DIY_Y0 = 0x04, /* minimum-side y (bottom) neighbor */ + DIY_Y1 = 0x08, /* maximum-side y (top) neighbor */ + DIY_Z0 = 0x10, /* minimum-side z (back) neighbor */ + DIY_Z1 = 0x20, /* maximum-side z (front)neighbor */ + DIY_T0 = 0x40, /* minimum-side t (earlier) neighbor */ + DIY_T1 = 0x80 /* maximum-side t (later) neighbor */ +}; + +#endif diff --git a/diy/include/diy/critical-resource.hpp b/diy/include/diy/critical-resource.hpp new file mode 100644 index 000000000..61a5a4b8a --- /dev/null +++ b/diy/include/diy/critical-resource.hpp @@ -0,0 +1,53 @@ +#ifndef DIY_CRITICAL_RESOURCE_HPP +#define DIY_CRITICAL_RESOURCE_HPP + +namespace diy +{ + // TODO: when not running under C++11, i.e., when lock_guard is TinyThread's + // lock_guard, and not C++11's unique_lock, this implementation might + // be buggy since the copy constructor is invoked when + // critical_resource::access() returns an instance of this class. Once + // the temporary is destroyed the mutex is unlocked. I'm not 100% + // certain of this because I'd expect a deadlock on copy constructor, + // but it's clearly not happening -- so I may be missing something. + // (This issue will take care of itself in DIY3 once we switch to C++11 completely.) 
+ template + class resource_accessor + { + public: + resource_accessor(T& x, Mutex& m): + x_(x), lock_(m) {} + + T& operator*() { return x_; } + T* operator->() { return &x_; } + const T& operator*() const { return x_; } + const T* operator->() const { return &x_; } + + private: + T& x_; + lock_guard lock_; + }; + + template + class critical_resource + { + public: + typedef resource_accessor accessor; + typedef resource_accessor const_accessor; // eventually, try shared locking + + public: + critical_resource() {} + critical_resource(const T& x): + x_(x) {} + + accessor access() { return accessor(x_, m_); } + const_accessor const_access() const { return const_accessor(x_, m_); } + + private: + T x_; + mutable Mutex m_; + }; +} + + +#endif diff --git a/diy/include/diy/decomposition.hpp b/diy/include/diy/decomposition.hpp new file mode 100644 index 000000000..51dfc5af2 --- /dev/null +++ b/diy/include/diy/decomposition.hpp @@ -0,0 +1,716 @@ +#ifndef DIY_DECOMPOSITION_HPP +#define DIY_DECOMPOSITION_HPP + +#include +#include +#include +#include +#include +#include + +#include "link.hpp" +#include "assigner.hpp" +#include "master.hpp" + +namespace diy +{ +namespace detail +{ + template + struct BoundsHelper; + + // discrete bounds + template + struct BoundsHelper::value>::type> + { + using Coordinate = typename Bounds::Coordinate; + + static Coordinate from(int i, int n, Coordinate min, Coordinate max, bool) { return min + (max - min + 1)/n * i; } + static Coordinate to (int i, int n, Coordinate min, Coordinate max, bool shared_face) + { + if (i == n - 1) + return max; + else + return from(i+1, n, min, max, shared_face) - (shared_face ? 
0 : 1); + } + + static int lower(Coordinate x, int n, Coordinate min, Coordinate max, bool shared) + { + Coordinate width = (max - min + 1)/n; + Coordinate res = (x - min)/width; + if (res >= n) res = n - 1; + + if (shared && x == from(res, n, min, max, shared)) + --res; + return res; + } + static int upper(Coordinate x, int n, Coordinate min, Coordinate max, bool shared) + { + Coordinate width = (max - min + 1)/n; + Coordinate res = (x - min)/width + 1; + if (shared && x == from(res, n, min, max, shared)) + ++res; + return res; + } + }; + + // continuous bounds + template + struct BoundsHelper::value>::type> + { + using Coordinate = typename Bounds::Coordinate; + + static Coordinate from(int i, int n, Coordinate min, Coordinate max, bool) { return min + (max - min)/n * i; } + static Coordinate to (int i, int n, Coordinate min, Coordinate max, bool) { return min + (max - min)/n * (i+1); } + + static int lower(Coordinate x, int n, Coordinate min, Coordinate max, bool) { Coordinate width = (max - min)/n; Coordinate res = std::floor((x - min)/width); if (min + res*width == x) return (res - 1); else return res; } + static int upper(Coordinate x, int n, Coordinate min, Coordinate max, bool) { Coordinate width = (max - min)/n; Coordinate res = std::ceil ((x - min)/width); if (min + res*width == x) return (res + 1); else return res; } + }; +} + + //! \ingroup Decomposition + //! Decomposes a regular (discrete or continuous) domain into even blocks; + //! creates Links with Bounds along the way. 
+ template + struct RegularDecomposer + { + typedef Bounds_ Bounds; + typedef typename BoundsValue::type Coordinate; + typedef typename RegularLinkSelector::type Link; + + using Creator = std::function; + using Updater = std::function; + + typedef std::vector BoolVector; + typedef std::vector CoordinateVector; + typedef std::vector DivisionsVector; + + /// @param dim: dimensionality of the decomposition + /// @param domain: bounds of global domain + /// @param nblocks: total number of global blocks + /// @param share_face: indicates dimensions on which to share block faces + /// @param wrap: indicates dimensions on which to wrap the boundary + /// @param ghosts: indicates how many ghosts to use in each dimension + /// @param divisions: indicates how many cuts to make along each dimension + /// (0 means "no constraint," i.e., leave it up to the algorithm) + RegularDecomposer(int dim_, + const Bounds& domain_, + int nblocks_, + BoolVector share_face_ = BoolVector(), + BoolVector wrap_ = BoolVector(), + CoordinateVector ghosts_ = CoordinateVector(), + DivisionsVector divisions_ = DivisionsVector()): + dim(dim_), domain(domain_), nblocks(nblocks_), + share_face(share_face_), + wrap(wrap_), ghosts(ghosts_), divisions(divisions_) + { + if ((int) share_face.size() < dim) share_face.resize(dim); + if ((int) wrap.size() < dim) wrap.resize(dim); + if ((int) ghosts.size() < dim) ghosts.resize(dim); + if ((int) divisions.size() < dim) divisions.resize(dim); + + fill_divisions(divisions); + } + + // Calls create(int gid, const Bounds& bounds, const Link& link) + void decompose(int rank, const Assigner& assigner, const Creator& create); + + void decompose(int rank, const Assigner& assigner, Master& master, const Updater& update); + + void decompose(int rank, const Assigner& assigner, Master& master); + + // find lowest gid that owns a particular point + template + int lowest_gid(const Point& p) const; + + void gid_to_coords(int gid, DivisionsVector& coords) const { 
gid_to_coords(gid, coords, divisions); } + int coords_to_gid(const DivisionsVector& coords) const { return coords_to_gid(coords, divisions); } + void fill_divisions(std::vector& divisions) const; + + void fill_bounds(Bounds& bounds, const DivisionsVector& coords, bool add_ghosts = false) const; + void fill_bounds(Bounds& bounds, int gid, bool add_ghosts = false) const; + + static bool all(const std::vector& v, int x); + static void gid_to_coords(int gid, DivisionsVector& coords, const DivisionsVector& divisions); + static int coords_to_gid(const DivisionsVector& coords, const DivisionsVector& divisions); + + static void factor(std::vector& factors, int n); + + // Point to GIDs functions + template + void point_to_gids(std::vector& gids, const Point& p) const; + + //! returns gid of a block that contains the point; ignores ghosts + template + int point_to_gid(const Point& p) const; + + template + int num_gids(const Point& p) const; + + template + void top_bottom(int& top, int& bottom, const Point& p, int axis) const; + + + int dim; + Bounds domain; + int nblocks; + BoolVector share_face; + BoolVector wrap; + CoordinateVector ghosts; + DivisionsVector divisions; + + }; + + /** + * \ingroup Decomposition + * \brief Decomposes the domain into a prescribed pattern of blocks. + * + * @param dim dimension of the domain + * @param rank local rank + * @param assigner decides how processors are assigned to blocks (maps a gid to a rank) + * also communicates the total number of blocks + * @param create the callback functor + * @param wrap indicates dimensions on which to wrap the boundary + * @param ghosts indicates how many ghosts to use in each dimension + * @param divs indicates how many cuts to make along each dimension + * (0 means "no constraint," i.e., leave it up to the algorithm) + * + * `create(...)` is called with each block assigned to the local domain. See [decomposition example](#decomposition-example). 
+ */ + template + void decompose(int dim, + int rank, + const Bounds& domain, + const Assigner& assigner, + const typename RegularDecomposer::Creator& create, + typename RegularDecomposer::BoolVector share_face = typename RegularDecomposer::BoolVector(), + typename RegularDecomposer::BoolVector wrap = typename RegularDecomposer::BoolVector(), + typename RegularDecomposer::CoordinateVector ghosts = typename RegularDecomposer::CoordinateVector(), + typename RegularDecomposer::DivisionsVector divs = typename RegularDecomposer::DivisionsVector()) + { + RegularDecomposer(dim, domain, assigner.nblocks(), share_face, wrap, ghosts, divs).decompose(rank, assigner, create); + } + + /** + * \ingroup Decomposition + * \brief Decomposes the domain into a prescribed pattern of blocks. + * + * @param dim dimension of the domain + * @param rank local rank + * @param assigner decides how processors are assigned to blocks (maps a gid to a rank) + * also communicates the total number of blocks + * @param master gets the blocks once this function returns + * @param wrap indicates dimensions on which to wrap the boundary + * @param ghosts indicates how many ghosts to use in each dimension + * @param divs indicates how many cuts to make along each dimension + * (0 means "no constraint," i.e., leave it up to the algorithm) + * + * `master` must have been supplied a create function in order for this function to work. 
+ */ + template + void decompose(int dim, + int rank, + const Bounds& domain, + const Assigner& assigner, + Master& master, + typename RegularDecomposer::BoolVector share_face = typename RegularDecomposer::BoolVector(), + typename RegularDecomposer::BoolVector wrap = typename RegularDecomposer::BoolVector(), + typename RegularDecomposer::CoordinateVector ghosts = typename RegularDecomposer::CoordinateVector(), + typename RegularDecomposer::DivisionsVector divs = typename RegularDecomposer::DivisionsVector()) + { + RegularDecomposer(dim, domain, assigner.nblocks(), share_face, wrap, ghosts, divs).decompose(rank, assigner, master); + } + + /** + * \ingroup Decomposition + * \brief A "null" decompositon that simply creates the blocks and adds them to the master + * + * @param rank local rank + * @param assigner decides how processors are assigned to blocks (maps a gid to a rank) + * also communicates the total number of blocks + * @param master gets the blocks once this function returns + */ + inline + void decompose(int rank, + const Assigner& assigner, + Master& master) + { + std::vector local_gids; + assigner.local_gids(rank, local_gids); + + for (size_t i = 0; i < local_gids.size(); ++i) + master.add(local_gids[i], master.create(), new diy::Link); + } + + /** + * \ingroup Decomposition + * \brief Add a decomposition (modify links) of an existing set of blocks that were + * added to the master previously + * + * @param rank local rank + * @param assigner decides how processors are assigned to blocks (maps a gid to a rank) + * also communicates the total number of blocks + */ + template + void decompose(int dim, + int rank, + const Bounds& domain, + const Assigner& assigner, + Master& master, + const typename RegularDecomposer::Updater& update, + typename RegularDecomposer::BoolVector share_face = + typename RegularDecomposer::BoolVector(), + typename RegularDecomposer::BoolVector wrap = + typename RegularDecomposer::BoolVector(), + typename 
RegularDecomposer::CoordinateVector ghosts = + typename RegularDecomposer::CoordinateVector(), + typename RegularDecomposer::DivisionsVector divs = + typename RegularDecomposer::DivisionsVector()) + { + RegularDecomposer(dim, domain, assigner.nblocks(), share_face, wrap, ghosts, divs). + decompose(rank, assigner, master, update); + } + + //! Decomposition example: \example decomposition/test-decomposition.cpp + //! Direct master insertion example: \example decomposition/test-direct-master.cpp +} + +// decomposes domain and adds blocks to the master +template +void +diy::RegularDecomposer:: +decompose(int rank, const Assigner& assigner, Master& master) +{ + decompose(rank, assigner, [&master](int gid, const Bounds& core, const Bounds& bounds, const Bounds& domain, const Link& link) + { + void* b = master.create(); + Link* l = new Link(link); + master.add(gid, b, l); + }); +} + +template +void +diy::RegularDecomposer:: +decompose(int rank, const Assigner& assigner, const Creator& create) +{ + std::vector gids; + assigner.local_gids(rank, gids); + for (int i = 0; i < (int)gids.size(); ++i) + { + int gid = gids[i]; + + DivisionsVector coords; + gid_to_coords(gid, coords); + + Bounds core, bounds; + fill_bounds(core, coords); + fill_bounds(bounds, coords, true); + + // Fill link with all the neighbors + Link link(dim, core, bounds); + std::vector offsets(dim, -1); + offsets[0] = -2; + while (!all(offsets, 1)) + { + // next offset + int i; + for (i = 0; i < dim; ++i) + if (offsets[i] == 1) + offsets[i] = -1; + else + break; + ++offsets[i]; + + if (all(offsets, 0)) continue; // skip ourselves + + DivisionsVector nhbr_coords(dim); + Direction dir, wrap_dir; + bool inbounds = true; + for (int i = 0; i < dim; ++i) + { + nhbr_coords[i] = coords[i] + offsets[i]; + + // wrap + if (nhbr_coords[i] < 0) + { + if (wrap[i]) + { + nhbr_coords[i] = divisions[i] - 1; + wrap_dir[i] = -1; + } + else + inbounds = false; + } + + if (nhbr_coords[i] >= divisions[i]) + { + if (wrap[i]) + { + 
nhbr_coords[i] = 0; + wrap_dir[i] = 1; + } + else + inbounds = false; + } + + // NB: this needs to match the addressing scheme in dir_t (in constants.h) + if (offsets[i] == -1 || offsets[i] == 1) + dir[i] = offsets[i]; + } + if (!inbounds) continue; + + int nhbr_gid = coords_to_gid(nhbr_coords); + BlockID bid; bid.gid = nhbr_gid; bid.proc = assigner.rank(nhbr_gid); + link.add_neighbor(bid); + + Bounds nhbr_bounds; + fill_bounds(nhbr_bounds, nhbr_coords); + link.add_bounds(nhbr_bounds); + + link.add_direction(dir); + link.add_wrap(wrap_dir); + } + + create(gid, core, bounds, domain, link); + } +} + +// decomposes domain but does not add blocks to master, assumes they were added already +template +void +diy::RegularDecomposer:: +decompose(int rank, const Assigner& assigner, Master& master, const Updater& update) +{ + decompose(rank, assigner, [&master,&update](int gid, const Bounds& core, const Bounds& bounds, const Bounds& domain, const Link& link) + { + int lid = master.lid(gid); + Link* l = new Link(link); + master.replace_link(lid, l); + update(gid, lid, core, bounds, domain, *l); + }); +} + +template +bool +diy::RegularDecomposer:: +all(const std::vector& v, int x) +{ + for (unsigned i = 0; i < v.size(); ++i) + if (v[i] != x) + return false; + return true; +} + +template +void +diy::RegularDecomposer:: +gid_to_coords(int gid, DivisionsVector& coords, const DivisionsVector& divisions) +{ + int dim = divisions.size(); + for (int i = 0; i < dim; ++i) + { + coords.push_back(gid % divisions[i]); + gid /= divisions[i]; + } +} + +template +int +diy::RegularDecomposer:: +coords_to_gid(const DivisionsVector& coords, const DivisionsVector& divisions) +{ + int gid = 0; + for (int i = coords.size() - 1; i >= 0; --i) + { + gid *= divisions[i]; + gid += coords[i]; + } + return gid; +} + +//! \ingroup Decomposition +//! 
Gets the bounds, with or without ghosts, for a block specified by its block coordinates +template +void +diy::RegularDecomposer:: +fill_bounds(Bounds& bounds, //!< (output) bounds + const DivisionsVector& coords, //!< coordinates of the block in the decomposition + bool add_ghosts) //!< whether to include ghosts in the output bounds + const +{ + for (int i = 0; i < dim; ++i) + { + bounds.min[i] = detail::BoundsHelper::from(coords[i], divisions[i], domain.min[i], domain.max[i], share_face[i]); + bounds.max[i] = detail::BoundsHelper::to (coords[i], divisions[i], domain.min[i], domain.max[i], share_face[i]); + } + + for (int i = dim; i < DIY_MAX_DIM; ++i) // set the unused dimension to 0 + { + bounds.min[i] = 0; + bounds.max[i] = 0; + } + + if (!add_ghosts) + return; + + for (int i = 0; i < dim; ++i) + { + if (wrap[i]) + { + bounds.min[i] -= ghosts[i]; + bounds.max[i] += ghosts[i]; + } else + { + bounds.min[i] = std::max(domain.min[i], bounds.min[i] - ghosts[i]); + bounds.max[i] = std::min(domain.max[i], bounds.max[i] + ghosts[i]); + } + } +} + +//! \ingroup Decomposition +//! Gets the bounds, with or without ghosts, for a block specified by its gid +template +void +diy::RegularDecomposer:: +fill_bounds(Bounds& bounds, //!< (output) bounds + int gid, //!< global id of the block + bool add_ghosts) //!< whether to include ghosts in the output bounds + const +{ + DivisionsVector coords; + gid_to_coords(gid, coords); + if (add_ghosts) + fill_bounds(bounds, coords, true); + else + fill_bounds(bounds, coords); +} + +namespace diy { namespace detail { +// current state of division in one dimension used in fill_divisions below +template +struct Div +{ + int dim; // 0, 1, 2, etc. e.g. for x, y, z etc. 
+ int nb; // number of blocks so far in this dimension + Coordinate b_size; // block size so far in this dimension + + // sort on descending block size unless tied, in which case + // sort on ascending num blocks in current dim unless tied, in which case + // sort on ascending dimension + bool operator<(Div rhs) const + { + // sort on second value of the pair unless tied, in which case sort on first + if (b_size == rhs.b_size) + { + if (nb == rhs.nb) + return(dim < rhs.dim); + return(nb < rhs.nb); + } + return(b_size > rhs.b_size); + } +}; +} } + +template +void +diy::RegularDecomposer:: +fill_divisions(std::vector& divisions) const +{ + // prod = number of blocks unconstrained by user; c = number of unconstrained dimensions + int prod = 1; int c = 0; + for (int i = 0; i < dim; ++i) + if (divisions[i] != 0) + { + prod *= divisions[i]; + ++c; + } + + if (nblocks % prod != 0) + throw std::runtime_error("Total number of blocks cannot be factored into provided divs"); + + if (c == (int) divisions.size()) // nothing to do; user provided all divs + return; + + // factor number of blocks left in unconstrained dimensions + // factorization is sorted from smallest to largest factors + std::vector factors; + factor(factors, nblocks/prod); + + using detail::Div; + std::vector< Div > missing_divs; // pairs consisting of (dim, #divs) + + // init missing_divs + for (int i = 0; i < dim; i++) + { + if (divisions[i] == 0) + { + Div div; + div.dim = i; + div.nb = 1; + div.b_size = domain.max[i] - domain.min[i]; + missing_divs.push_back(div); + } + } + + // iterate over factorization of number of blocks (factors are sorted smallest to largest) + // NB: using int instead of size_t because must be negative in order to break out of loop + for (int i = factors.size() - 1; i >= 0; --i) + { + // fill in missing divs by dividing dimension w/ largest block size + // except when this would be illegal (resulting in bounds.max < bounds.min; + // only a problem for discrete bounds + + // sort on 
decreasing block size + std::sort(missing_divs.begin(), missing_divs.end()); + + // split the dimension with the largest block size (first element in vector) + Coordinate min = + detail::BoundsHelper::from(0, + missing_divs[0].nb * factors[i], + domain.min[missing_divs[0].dim], + domain.max[missing_divs[0].dim], + share_face[missing_divs[0].dim]); + Coordinate max = + detail::BoundsHelper::to(0, + missing_divs[0].nb * factors[i], + domain.min[missing_divs[0].dim], + domain.max[missing_divs[0].dim], + share_face[missing_divs[0].dim]); + if (max >= min) + { + missing_divs[0].nb *= factors[i]; + missing_divs[0].b_size = max - min; + } + else + { + std::ostringstream oss; + oss << "Unable to decompose domain into " << nblocks << " blocks: " << min << " " << max; + throw std::runtime_error(oss.str()); + } + } + + // assign the divisions + for (size_t i = 0; i < missing_divs.size(); i++) + divisions[missing_divs[i].dim] = missing_divs[i].nb; +} + +template +void +diy::RegularDecomposer:: +factor(std::vector& factors, int n) +{ + while (n != 1) + for (int i = 2; i <= n; ++i) + { + if (n % i == 0) + { + factors.push_back(i); + n /= i; + break; + } + } +} + +// Point to GIDs +// TODO: deal with wrap correctly +// TODO: add an optional ghosts argument to ignore ghosts (if we want to find the true owners, or something like that) +template +template +void +diy::RegularDecomposer:: +point_to_gids(std::vector& gids, const Point& p) const +{ + std::vector< std::pair > ranges(dim); + for (int i = 0; i < dim; ++i) + top_bottom(ranges[i].second, ranges[i].first, p, i); + + // look up gids for all combinations + DivisionsVector coords(dim), location(dim); + while(location.back() < ranges.back().second - ranges.back().first) + { + for (int i = 0; i < dim; ++i) + coords[i] = ranges[i].first + location[i]; + gids.push_back(coords_to_gid(coords, divisions)); + + location[0]++; + unsigned i = 0; + while (i < dim-1 && location[i] == ranges[i].second - ranges[i].first) + { + location[i] = 
0; + ++i; + location[i]++; + } + } +} + +template +template +int +diy::RegularDecomposer:: +point_to_gid(const Point& p) const +{ + int gid = 0; + for (int axis = dim - 1; axis >= 0; --axis) + { + int bottom = detail::BoundsHelper::lower(p[axis], divisions[axis], domain.min[axis], domain.max[axis], share_face[axis]); + bottom = std::max(0, bottom); + + // coupled with coords_to_gid + gid *= divisions[axis]; + gid += bottom; + } + + return gid; +} + +template +template +int +diy::RegularDecomposer:: +num_gids(const Point& p) const +{ + int res = 1; + for (int i = 0; i < dim; ++i) + { + int top, bottom; + top_bottom(top, bottom, p, i); + res *= top - bottom; + } + return res; +} + +template +template +void +diy::RegularDecomposer:: +top_bottom(int& top, int& bottom, const Point& p, int axis) const +{ + Coordinate l = p[axis] - ghosts[axis]; + Coordinate r = p[axis] + ghosts[axis]; + + top = detail::BoundsHelper::upper(r, divisions[axis], domain.min[axis], domain.max[axis], share_face[axis]); + bottom = detail::BoundsHelper::lower(l, divisions[axis], domain.min[axis], domain.max[axis], share_face[axis]); + + if (!wrap[axis]) + { + bottom = std::max(0, bottom); + top = std::min(divisions[axis], top); + } +} + +// find lowest gid that owns a particular point +template +template +int +diy::RegularDecomposer:: +lowest_gid(const Point& p) const +{ + // TODO: optimize - no need to compute all gids + std::vector gids; + point_to_gids(gids, p); + std::sort(gids.begin(), gids.end()); + return gids[0]; +} + +#endif diff --git a/diy/include/diy/detail/algorithms/kdtree-sampling.hpp b/diy/include/diy/detail/algorithms/kdtree-sampling.hpp new file mode 100644 index 000000000..7cf2ee1e5 --- /dev/null +++ b/diy/include/diy/detail/algorithms/kdtree-sampling.hpp @@ -0,0 +1,450 @@ +#ifndef DIY_DETAIL_ALGORITHMS_KDTREE_SAMPLING_HPP +#define DIY_DETAIL_ALGORITHMS_KDTREE_SAMPLING_HPP + +#include +#include +#include "../../partners/all-reduce.hpp" +#include "../../log.hpp" + +// TODO: 
technically, what's done now is not a perfect subsample: +// we take the same number of samples from every block, in reality this number should be selected at random, +// so that the total number of samples adds up to samples*nblocks +// +// NB: random samples are chosen using rand(), which is assumed to be seeded +// externally. Once we switch to C++11, we should use its more advanced +// random number generators (and take a generator as an external parameter) +// (TODO) + +namespace diy +{ +namespace detail +{ + +template +struct KDTreeSamplingPartition +{ + typedef diy::RegularContinuousLink RCLink; + typedef diy::ContinuousBounds Bounds; + + typedef std::vector Samples; + + KDTreeSamplingPartition(int dim, + std::vector Block::* points, + size_t samples): + dim_(dim), points_(points), samples_(samples) {} + + void operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const; + + int divide_gid(int gid, bool lower, int round, int rounds) const; + void update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const; + void split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const; + diy::Direction + find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const; + + void compute_local_samples(Block* b, const diy::ReduceProxy& srp, int dim) const; + void add_samples(Block* b, const diy::ReduceProxy& srp, Samples& samples) const; + void receive_samples(Block* b, const diy::ReduceProxy& srp, Samples& samples) const; + void forward_samples(Block* b, const diy::ReduceProxy& srp, const Samples& samples) const; + + void enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Samples& samples) const; + void dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const; + + void update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const; + bool intersects(const Bounds& x, const Bounds& y, int dim, bool 
wrap, const Bounds& domain) const; + float find_split(const Bounds& changed, const Bounds& original) const; + + int dim_; + std::vector Block::* points_; + size_t samples_; +}; + +} +} + + +template +void +diy::detail::KDTreeSamplingPartition:: +operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const +{ + int dim; + if (srp.round() < partners.rounds()) + dim = partners.dim(srp.round()); + else + dim = partners.dim(srp.round() - 1); + + if (srp.round() == partners.rounds()) + update_links(b, srp, dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain); // -1 would be the "uninformative" link round + else if (partners.swap_round(srp.round()) && partners.sub_round(srp.round()) < 0) // link round + { + dequeue_exchange(b, srp, dim); // from the swap round + split_to_neighbors(b, srp, dim); + } + else if (partners.swap_round(srp.round())) + { + Samples samples; + receive_samples(b, srp, samples); + enqueue_exchange(b, srp, dim, samples); + } else if (partners.sub_round(srp.round()) == 0) + { + if (srp.round() > 0) + { + int prev_dim = dim - 1; + if (prev_dim < 0) + prev_dim += dim_; + update_links(b, srp, prev_dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain); // -1 would be the "uninformative" link round + } + + compute_local_samples(b, srp, dim); + } else if (partners.sub_round(srp.round()) < (int) partners.histogram.rounds()/2) // we are reusing partners class, so really we are talking about the samples rounds here + { + Samples samples; + add_samples(b, srp, samples); + srp.enqueue(srp.out_link().target(0), samples); + } else + { + Samples samples; + add_samples(b, srp, samples); + if (samples.size() != 1) + { + // pick the median + std::nth_element(samples.begin(), samples.begin() + samples.size()/2, samples.end()); + std::swap(samples[0], samples[samples.size()/2]); + //std::sort(samples.begin(), samples.end()); + //samples[0] = 
(samples[samples.size()/2] + samples[samples.size()/2 + 1])/2; + samples.resize(1); + } + forward_samples(b, srp, samples); + } +} + +template +int +diy::detail::KDTreeSamplingPartition:: +divide_gid(int gid, bool lower, int round, int rounds) const +{ + if (lower) + gid &= ~(1 << (rounds - 1 - round)); + else + gid |= (1 << (rounds - 1 - round)); + return gid; +} + +// round here is the outer iteration of the algorithm +template +void +diy::detail::KDTreeSamplingPartition:: +update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const +{ + auto log = get_logger(); + int gid = srp.gid(); + int lid = srp.master()->lid(gid); + RCLink* link = static_cast(srp.master()->link(lid)); + + // (gid, dir) -> i + std::map, int> link_map; + for (int i = 0; i < link->size(); ++i) + link_map[std::make_pair(link->target(i).gid, link->direction(i))] = i; + + // NB: srp.enqueue(..., ...) should match the link + std::vector splits(link->size()); + for (int i = 0; i < link->size(); ++i) + { + float split; diy::Direction dir; + + int in_gid = link->target(i).gid; + while(srp.incoming(in_gid)) + { + srp.dequeue(in_gid, split); + srp.dequeue(in_gid, dir); + + // reverse dir + for (int j = 0; j < dim_; ++j) + dir[j] = -dir[j]; + + int k = link_map[std::make_pair(in_gid, dir)]; + log->trace("{} {} {} -> {}", in_gid, dir, split, k); + splits[k] = split; + } + } + + RCLink new_link(dim_, link->core(), link->core()); + + bool lower = !(gid & (1 << (rounds - 1 - round))); + + // fill out the new link + for (int i = 0; i < link->size(); ++i) + { + diy::Direction dir = link->direction(i); + //diy::Direction wrap_dir = link->wrap(i); // we don't use existing wrap, but restore it from scratch + if (dir[dim] != 0) + { + if ((dir[dim] < 0 && lower) || (dir[dim] > 0 && !lower)) + { + int nbr_gid = divide_gid(link->target(i).gid, !lower, round, rounds); + diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) }; + 
new_link.add_neighbor(nbr); + + new_link.add_direction(dir); + + Bounds bounds = link->bounds(i); + update_neighbor_bounds(bounds, splits[i], dim, !lower); + new_link.add_bounds(bounds); + + if (wrap) + new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain)); + else + new_link.add_wrap(diy::Direction()); + } + } else // non-aligned side + { + for (int j = 0; j < 2; ++j) + { + int nbr_gid = divide_gid(link->target(i).gid, j == 0, round, rounds); + + Bounds bounds = link->bounds(i); + update_neighbor_bounds(bounds, splits[i], dim, j == 0); + + if (intersects(bounds, new_link.bounds(), dim, wrap, domain)) + { + diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) }; + new_link.add_neighbor(nbr); + new_link.add_direction(dir); + new_link.add_bounds(bounds); + + if (wrap) + new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain)); + else + new_link.add_wrap(diy::Direction()); + } + } + } + } + + // add link to the dual block + int dual_gid = divide_gid(gid, !lower, round, rounds); + diy::BlockID dual = { dual_gid, srp.assigner().rank(dual_gid) }; + new_link.add_neighbor(dual); + + Bounds nbr_bounds = link->bounds(); // old block bounds + update_neighbor_bounds(nbr_bounds, find_split(new_link.bounds(), nbr_bounds), dim, !lower); + new_link.add_bounds(nbr_bounds); + + new_link.add_wrap(diy::Direction()); // dual block cannot be wrapped + + if (lower) + { + diy::Direction right; + right[dim] = 1; + new_link.add_direction(right); + } else + { + diy::Direction left; + left[dim] = -1; + new_link.add_direction(left); + } + + // update the link; notice that this won't conflict with anything since + // reduce is using its own notion of the link constructed through the + // partners + link->swap(new_link); +} + +template +void +diy::detail::KDTreeSamplingPartition:: +split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const +{ + int lid = srp.master()->lid(srp.gid()); + RCLink* link = static_cast(srp.master()->link(lid)); + + // determine split 
+ float split = find_split(link->core(), link->bounds()); + + for (int i = 0; i < link->size(); ++i) + { + srp.enqueue(link->target(i), split); + srp.enqueue(link->target(i), link->direction(i)); + } +} + +template +void +diy::detail::KDTreeSamplingPartition:: +compute_local_samples(Block* b, const diy::ReduceProxy& srp, int dim) const +{ + // compute and enqueue local samples + Samples samples; + size_t points_size = (b->*points_).size(); + size_t n = std::min(points_size, samples_); + samples.reserve(n); + for (size_t i = 0; i < n; ++i) + { + float x = (b->*points_)[rand() % points_size][dim]; + samples.push_back(x); + } + + srp.enqueue(srp.out_link().target(0), samples); +} + +template +void +diy::detail::KDTreeSamplingPartition:: +add_samples(Block* b, const diy::ReduceProxy& srp, Samples& samples) const +{ + // dequeue and combine the samples + for (int i = 0; i < srp.in_link().size(); ++i) + { + int nbr_gid = srp.in_link().target(i).gid; + + Samples smpls; + srp.dequeue(nbr_gid, smpls); + for (size_t i = 0; i < smpls.size(); ++i) + samples.push_back(smpls[i]); + } +} + +template +void +diy::detail::KDTreeSamplingPartition:: +receive_samples(Block* b, const diy::ReduceProxy& srp, Samples& samples) const +{ + srp.dequeue(srp.in_link().target(0).gid, samples); +} + +template +void +diy::detail::KDTreeSamplingPartition:: +forward_samples(Block* b, const diy::ReduceProxy& srp, const Samples& samples) const +{ + for (int i = 0; i < srp.out_link().size(); ++i) + srp.enqueue(srp.out_link().target(i), samples); +} + +template +void +diy::detail::KDTreeSamplingPartition:: +enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Samples& samples) const +{ + int lid = srp.master()->lid(srp.gid()); + RCLink* link = static_cast(srp.master()->link(lid)); + + int k = srp.out_link().size(); + + if (k == 0) // final round; nothing needs to be sent; this is actually redundant + return; + + // pick split points + float split = samples[0]; + + // subset and enqueue 
+ std::vector< std::vector > out_points(srp.out_link().size()); + for (size_t i = 0; i < (b->*points_).size(); ++i) + { + float x = (b->*points_)[i][dim]; + int loc = x < split ? 0 : 1; + out_points[loc].push_back((b->*points_)[i]); + } + int pos = -1; + for (int i = 0; i < k; ++i) + { + if (srp.out_link().target(i).gid == srp.gid()) + { + (b->*points_).swap(out_points[i]); + pos = i; + } + else + srp.enqueue(srp.out_link().target(i), out_points[i]); + } + if (pos == 0) + link->core().max[dim] = split; + else + link->core().min[dim] = split; +} + +template +void +diy::detail::KDTreeSamplingPartition:: +dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const +{ + int lid = srp.master()->lid(srp.gid()); + RCLink* link = static_cast(srp.master()->link(lid)); + + for (int i = 0; i < srp.in_link().size(); ++i) + { + int nbr_gid = srp.in_link().target(i).gid; + if (nbr_gid == srp.gid()) + continue; + + std::vector in_points; + srp.dequeue(nbr_gid, in_points); + for (size_t j = 0; j < in_points.size(); ++j) + { + if (in_points[j][dim] < link->core().min[dim] || in_points[j][dim] > link->core().max[dim]) + throw std::runtime_error(fmt::format("Dequeued {} outside [{},{}] ({})", + in_points[j][dim], link->core().min[dim], link->core().max[dim], dim)); + (b->*points_).push_back(in_points[j]); + } + } +} + +template +void +diy::detail::KDTreeSamplingPartition:: +update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const +{ + if (lower) + bounds.max[dim] = split; + else + bounds.min[dim] = split; +} + +template +bool +diy::detail::KDTreeSamplingPartition:: +intersects(const Bounds& x, const Bounds& y, int dim, bool wrap, const Bounds& domain) const +{ + if (wrap) + { + if (x.min[dim] == domain.min[dim] && y.max[dim] == domain.max[dim]) + return true; + if (y.min[dim] == domain.min[dim] && x.max[dim] == domain.max[dim]) + return true; + } + return x.min[dim] <= y.max[dim] && y.min[dim] <= x.max[dim]; +} + +template +float 
+diy::detail::KDTreeSamplingPartition:: +find_split(const Bounds& changed, const Bounds& original) const +{ + for (int i = 0; i < dim_; ++i) + { + if (changed.min[i] != original.min[i]) + return changed.min[i]; + if (changed.max[i] != original.max[i]) + return changed.max[i]; + } + assert(0); + return -1; +} + +template +diy::Direction +diy::detail::KDTreeSamplingPartition:: +find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const +{ + diy::Direction wrap; + for (int i = 0; i < dim_; ++i) + { + if (bounds.min[i] == domain.min[i] && nbr_bounds.max[i] == domain.max[i]) + wrap[i] = -1; + if (bounds.max[i] == domain.max[i] && nbr_bounds.min[i] == domain.min[i]) + wrap[i] = 1; + } + return wrap; +} + + +#endif diff --git a/diy/include/diy/detail/algorithms/kdtree.hpp b/diy/include/diy/detail/algorithms/kdtree.hpp new file mode 100644 index 000000000..286929dc9 --- /dev/null +++ b/diy/include/diy/detail/algorithms/kdtree.hpp @@ -0,0 +1,569 @@ +#ifndef DIY_DETAIL_ALGORITHMS_KDTREE_HPP +#define DIY_DETAIL_ALGORITHMS_KDTREE_HPP + +#include +#include +#include "../../partners/all-reduce.hpp" +#include "../../log.hpp" + +namespace diy +{ +namespace detail +{ + +struct KDTreePartners; + +template +struct KDTreePartition +{ + typedef diy::RegularContinuousLink RCLink; + typedef diy::ContinuousBounds Bounds; + + typedef std::vector Histogram; + + KDTreePartition(int dim, + std::vector Block::* points, + size_t bins): + dim_(dim), points_(points), bins_(bins) {} + + void operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const; + + int divide_gid(int gid, bool lower, int round, int rounds) const; + void update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const; + void split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const; + diy::Direction + find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const; + + void 
compute_local_histogram(Block* b, const diy::ReduceProxy& srp, int dim) const; + void add_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const; + void receive_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const; + void forward_histogram(Block* b, const diy::ReduceProxy& srp, const Histogram& histogram) const; + + void enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Histogram& histogram) const; + void dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const; + + void update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const; + bool intersects(const Bounds& x, const Bounds& y, int dim, bool wrap, const Bounds& domain) const; + float find_split(const Bounds& changed, const Bounds& original) const; + + int dim_; + std::vector Block::* points_; + size_t bins_; +}; + +} +} + +struct diy::detail::KDTreePartners +{ + // bool = are we in a swap (vs histogram) round + // int = round within that partner + typedef std::pair RoundType; + typedef diy::ContinuousBounds Bounds; + + KDTreePartners(int dim, int nblocks, bool wrap_, const Bounds& domain_): + decomposer(1, interval(0,nblocks-1), nblocks), + histogram(decomposer, 2), + swap(decomposer, 2, false), + wrap(wrap_), + domain(domain_) + { + for (unsigned i = 0; i < swap.rounds(); ++i) + { + // fill histogram rounds + for (unsigned j = 0; j < histogram.rounds(); ++j) + { + rounds_.push_back(std::make_pair(false, j)); + dim_.push_back(i % dim); + if (j == histogram.rounds() / 2 - 1 - i) + j += 2*i; + } + + // fill swap round + rounds_.push_back(std::make_pair(true, i)); + dim_.push_back(i % dim); + + // fill link round + rounds_.push_back(std::make_pair(true, -1)); // (true, -1) signals link round + dim_.push_back(i % dim); + } + } + + size_t rounds() const { return rounds_.size(); } + size_t swap_rounds() const { return swap.rounds(); } + + int dim(int round) const { return dim_[round]; } + bool swap_round(int round) 
const { return rounds_[round].first; } + int sub_round(int round) const { return rounds_[round].second; } + + inline bool active(int round, int gid, const diy::Master& m) const + { + if (round == (int) rounds()) + return true; + else if (swap_round(round) && sub_round(round) < 0) // link round + return true; + else if (swap_round(round)) + return swap.active(sub_round(round), gid, m); + else + return histogram.active(sub_round(round), gid, m); + } + + inline void incoming(int round, int gid, std::vector& partners, const diy::Master& m) const + { + if (round == (int) rounds()) + link_neighbors(-1, gid, partners, m); + else if (swap_round(round) && sub_round(round) < 0) // link round + swap.incoming(sub_round(round - 1) + 1, gid, partners, m); + else if (swap_round(round)) + histogram.incoming(histogram.rounds(), gid, partners, m); + else + { + if (round > 0 && sub_round(round) == 0) + link_neighbors(-1, gid, partners, m); + else if (round > 0 && sub_round(round - 1) != sub_round(round) - 1) // jump through the histogram rounds + histogram.incoming(sub_round(round - 1) + 1, gid, partners, m); + else + histogram.incoming(sub_round(round), gid, partners, m); + } + } + + inline void outgoing(int round, int gid, std::vector& partners, const diy::Master& m) const + { + if (round == (int) rounds()) + swap.outgoing(sub_round(round-1) + 1, gid, partners, m); + else if (swap_round(round) && sub_round(round) < 0) // link round + link_neighbors(-1, gid, partners, m); + else if (swap_round(round)) + swap.outgoing(sub_round(round), gid, partners, m); + else + histogram.outgoing(sub_round(round), gid, partners, m); + } + + inline void link_neighbors(int, int gid, std::vector& partners, const diy::Master& m) const + { + int lid = m.lid(gid); + diy::Link* link = m.link(lid); + + std::set result; // partners must be unique + for (int i = 0; i < link->size(); ++i) + result.insert(link->target(i).gid); + + for (std::set::const_iterator it = result.begin(); it != result.end(); ++it) + 
partners.push_back(*it); + } + + // 1-D domain to feed into histogram and swap + diy::RegularDecomposer decomposer; + + diy::RegularAllReducePartners histogram; + diy::RegularSwapPartners swap; + + std::vector rounds_; + std::vector dim_; + + bool wrap; + Bounds domain; +}; + +template +void +diy::detail::KDTreePartition:: +operator()(Block* b, const diy::ReduceProxy& srp, const KDTreePartners& partners) const +{ + int dim; + if (srp.round() < partners.rounds()) + dim = partners.dim(srp.round()); + else + dim = partners.dim(srp.round() - 1); + + if (srp.round() == partners.rounds()) + update_links(b, srp, dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain); // -1 would be the "uninformative" link round + else if (partners.swap_round(srp.round()) && partners.sub_round(srp.round()) < 0) // link round + { + dequeue_exchange(b, srp, dim); // from the swap round + split_to_neighbors(b, srp, dim); + } + else if (partners.swap_round(srp.round())) + { + Histogram histogram; + receive_histogram(b, srp, histogram); + enqueue_exchange(b, srp, dim, histogram); + } else if (partners.sub_round(srp.round()) == 0) + { + if (srp.round() > 0) + { + int prev_dim = dim - 1; + if (prev_dim < 0) + prev_dim += dim_; + update_links(b, srp, prev_dim, partners.sub_round(srp.round() - 2), partners.swap_rounds(), partners.wrap, partners.domain); // -1 would be the "uninformative" link round + } + + compute_local_histogram(b, srp, dim); + } else if (partners.sub_round(srp.round()) < (int) partners.histogram.rounds()/2) + { + Histogram histogram(bins_); + add_histogram(b, srp, histogram); + srp.enqueue(srp.out_link().target(0), histogram); + } + else + { + Histogram histogram(bins_); + add_histogram(b, srp, histogram); + forward_histogram(b, srp, histogram); + } +} + +template +int +diy::detail::KDTreePartition:: +divide_gid(int gid, bool lower, int round, int rounds) const +{ + if (lower) + gid &= ~(1 << (rounds - 1 - round)); + else + gid |= (1 << 
(rounds - 1 - round)); + return gid; +} + +// round here is the outer iteration of the algorithm +template +void +diy::detail::KDTreePartition:: +update_links(Block* b, const diy::ReduceProxy& srp, int dim, int round, int rounds, bool wrap, const Bounds& domain) const +{ + int gid = srp.gid(); + int lid = srp.master()->lid(gid); + RCLink* link = static_cast(srp.master()->link(lid)); + + // (gid, dir) -> i + std::map, int> link_map; + for (int i = 0; i < link->size(); ++i) + link_map[std::make_pair(link->target(i).gid, link->direction(i))] = i; + + // NB: srp.enqueue(..., ...) should match the link + std::vector splits(link->size()); + for (int i = 0; i < link->size(); ++i) + { + float split; diy::Direction dir; + + int in_gid = link->target(i).gid; + while(srp.incoming(in_gid)) + { + srp.dequeue(in_gid, split); + srp.dequeue(in_gid, dir); + + // reverse dir + for (int j = 0; j < dim_; ++j) + dir[j] = -dir[j]; + + int k = link_map[std::make_pair(in_gid, dir)]; + splits[k] = split; + } + } + + RCLink new_link(dim_, link->core(), link->core()); + + bool lower = !(gid & (1 << (rounds - 1 - round))); + + // fill out the new link + for (int i = 0; i < link->size(); ++i) + { + diy::Direction dir = link->direction(i); + //diy::Direction wrap_dir = link->wrap(i); // we don't use existing wrap, but restore it from scratch + if (dir[dim] != 0) + { + if ((dir[dim] < 0 && lower) || (dir[dim] > 0 && !lower)) + { + int nbr_gid = divide_gid(link->target(i).gid, !lower, round, rounds); + diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) }; + new_link.add_neighbor(nbr); + + new_link.add_direction(dir); + + Bounds bounds = link->bounds(i); + update_neighbor_bounds(bounds, splits[i], dim, !lower); + new_link.add_bounds(bounds); + + if (wrap) + new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain)); + else + new_link.add_wrap(diy::Direction()); + } + } else // non-aligned side + { + for (int j = 0; j < 2; ++j) + { + int nbr_gid = divide_gid(link->target(i).gid, j == 
0, round, rounds); + + Bounds bounds = link->bounds(i); + update_neighbor_bounds(bounds, splits[i], dim, j == 0); + + if (intersects(bounds, new_link.bounds(), dim, wrap, domain)) + { + diy::BlockID nbr = { nbr_gid, srp.assigner().rank(nbr_gid) }; + new_link.add_neighbor(nbr); + new_link.add_direction(dir); + new_link.add_bounds(bounds); + + if (wrap) + new_link.add_wrap(find_wrap(new_link.bounds(), bounds, domain)); + else + new_link.add_wrap(diy::Direction()); + } + } + } + } + + // add link to the dual block + int dual_gid = divide_gid(gid, !lower, round, rounds); + diy::BlockID dual = { dual_gid, srp.assigner().rank(dual_gid) }; + new_link.add_neighbor(dual); + + Bounds nbr_bounds = link->bounds(); // old block bounds + update_neighbor_bounds(nbr_bounds, find_split(new_link.bounds(), nbr_bounds), dim, !lower); + new_link.add_bounds(nbr_bounds); + + new_link.add_wrap(diy::Direction()); // dual block cannot be wrapped + + if (lower) + { + diy::Direction right; + right[dim] = 1; + new_link.add_direction(right); + } else + { + diy::Direction left; + left[dim] = -1; + new_link.add_direction(left); + } + + // update the link; notice that this won't conflict with anything since + // reduce is using its own notion of the link constructed through the + // partners + link->swap(new_link); +} + +template +void +diy::detail::KDTreePartition:: +split_to_neighbors(Block* b, const diy::ReduceProxy& srp, int dim) const +{ + int lid = srp.master()->lid(srp.gid()); + RCLink* link = static_cast(srp.master()->link(lid)); + + // determine split + float split = find_split(link->core(), link->bounds()); + + for (int i = 0; i < link->size(); ++i) + { + srp.enqueue(link->target(i), split); + srp.enqueue(link->target(i), link->direction(i)); + } +} + +template +void +diy::detail::KDTreePartition:: +compute_local_histogram(Block* b, const diy::ReduceProxy& srp, int dim) const +{ + int lid = srp.master()->lid(srp.gid()); + RCLink* link = static_cast(srp.master()->link(lid)); + + // 
compute and enqueue local histogram + Histogram histogram(bins_); + + float width = (link->core().max[dim] - link->core().min[dim])/bins_; + for (size_t i = 0; i < (b->*points_).size(); ++i) + { + float x = (b->*points_)[i][dim]; + int loc = (x - link->core().min[dim]) / width; + if (loc < 0) + throw std::runtime_error(fmt::format("{} {} {}", loc, x, link->core().min[dim])); + if (loc >= (int) bins_) + loc = bins_ - 1; + ++(histogram[loc]); + } + + srp.enqueue(srp.out_link().target(0), histogram); +} + +template +void +diy::detail::KDTreePartition:: +add_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const +{ + // dequeue and add up the histograms + for (int i = 0; i < srp.in_link().size(); ++i) + { + int nbr_gid = srp.in_link().target(i).gid; + + Histogram hist; + srp.dequeue(nbr_gid, hist); + for (size_t i = 0; i < hist.size(); ++i) + histogram[i] += hist[i]; + } +} + +template +void +diy::detail::KDTreePartition:: +receive_histogram(Block* b, const diy::ReduceProxy& srp, Histogram& histogram) const +{ + srp.dequeue(srp.in_link().target(0).gid, histogram); +} + +template +void +diy::detail::KDTreePartition:: +forward_histogram(Block* b, const diy::ReduceProxy& srp, const Histogram& histogram) const +{ + for (int i = 0; i < srp.out_link().size(); ++i) + srp.enqueue(srp.out_link().target(i), histogram); +} + +template +void +diy::detail::KDTreePartition:: +enqueue_exchange(Block* b, const diy::ReduceProxy& srp, int dim, const Histogram& histogram) const +{ + auto log = get_logger(); + + int lid = srp.master()->lid(srp.gid()); + RCLink* link = static_cast(srp.master()->link(lid)); + + int k = srp.out_link().size(); + + if (k == 0) // final round; nothing needs to be sent; this is actually redundant + return; + + // pick split points + size_t total = 0; + for (size_t i = 0; i < histogram.size(); ++i) + total += histogram[i]; + log->trace("Histogram total: {}", total); + + size_t cur = 0; + float width = (link->core().max[dim] - 
link->core().min[dim])/bins_; + float split = 0; + for (size_t i = 0; i < histogram.size(); ++i) + { + if (cur + histogram[i] > total/2) + { + split = link->core().min[dim] + width*i; + break; + } + cur += histogram[i]; + } + log->trace("Found split: {} (dim={}) in {} - {}", split, dim, link->core().min[dim], link->core().max[dim]); + + // subset and enqueue + std::vector< std::vector > out_points(srp.out_link().size()); + for (size_t i = 0; i < (b->*points_).size(); ++i) + { + float x = (b->*points_)[i][dim]; + int loc = x < split ? 0 : 1; + out_points[loc].push_back((b->*points_)[i]); + } + int pos = -1; + for (int i = 0; i < k; ++i) + { + if (srp.out_link().target(i).gid == srp.gid()) + { + (b->*points_).swap(out_points[i]); + pos = i; + } + else + srp.enqueue(srp.out_link().target(i), out_points[i]); + } + if (pos == 0) + link->core().max[dim] = split; + else + link->core().min[dim] = split; +} + +template +void +diy::detail::KDTreePartition:: +dequeue_exchange(Block* b, const diy::ReduceProxy& srp, int dim) const +{ + int lid = srp.master()->lid(srp.gid()); + RCLink* link = static_cast(srp.master()->link(lid)); + + for (int i = 0; i < srp.in_link().size(); ++i) + { + int nbr_gid = srp.in_link().target(i).gid; + if (nbr_gid == srp.gid()) + continue; + + std::vector in_points; + srp.dequeue(nbr_gid, in_points); + for (size_t j = 0; j < in_points.size(); ++j) + { + if (in_points[j][dim] < link->core().min[dim] || in_points[j][dim] > link->core().max[dim]) + throw std::runtime_error(fmt::format("Dequeued {} outside [{},{}] ({})", + in_points[j][dim], link->core().min[dim], link->core().max[dim], dim)); + (b->*points_).push_back(in_points[j]); + } + } +} + +template +void +diy::detail::KDTreePartition:: +update_neighbor_bounds(Bounds& bounds, float split, int dim, bool lower) const +{ + if (lower) + bounds.max[dim] = split; + else + bounds.min[dim] = split; +} + +template +bool +diy::detail::KDTreePartition:: +intersects(const Bounds& x, const Bounds& y, int dim, 
bool wrap, const Bounds& domain) const +{ + if (wrap) + { + if (x.min[dim] == domain.min[dim] && y.max[dim] == domain.max[dim]) + return true; + if (y.min[dim] == domain.min[dim] && x.max[dim] == domain.max[dim]) + return true; + } + return x.min[dim] <= y.max[dim] && y.min[dim] <= x.max[dim]; +} + +template +float +diy::detail::KDTreePartition:: +find_split(const Bounds& changed, const Bounds& original) const +{ + for (int i = 0; i < dim_; ++i) + { + if (changed.min[i] != original.min[i]) + return changed.min[i]; + if (changed.max[i] != original.max[i]) + return changed.max[i]; + } + assert(0); + return -1; +} + +template +diy::Direction +diy::detail::KDTreePartition:: +find_wrap(const Bounds& bounds, const Bounds& nbr_bounds, const Bounds& domain) const +{ + diy::Direction wrap; + for (int i = 0; i < dim_; ++i) + { + if (bounds.min[i] == domain.min[i] && nbr_bounds.max[i] == domain.max[i]) + wrap[i] = -1; + if (bounds.max[i] == domain.max[i] && nbr_bounds.min[i] == domain.min[i]) + wrap[i] = 1; + } + return wrap; +} + + +#endif diff --git a/diy/include/diy/detail/algorithms/sort.hpp b/diy/include/diy/detail/algorithms/sort.hpp new file mode 100644 index 000000000..5cc3f8807 --- /dev/null +++ b/diy/include/diy/detail/algorithms/sort.hpp @@ -0,0 +1,162 @@ +#ifndef DIY_DETAIL_ALGORITHMS_SORT_HPP +#define DIY_DETAIL_ALGORITHMS_SORT_HPP + +#include +#include + +namespace diy +{ + +namespace detail +{ + +template +struct SampleSort +{ + typedef std::vector Block::*ValuesVector; + struct Sampler; + struct Exchanger; + + SampleSort(ValuesVector values_, ValuesVector samples_, const Cmp& cmp_, size_t num_samples_): + values(values_), samples(samples_), + cmp(cmp_), num_samples(num_samples_) {} + + Sampler sample() const { return Sampler(values, samples, cmp, num_samples); } + Exchanger exchange() const { return Exchanger(values, samples, cmp); } + + static void dequeue_values(std::vector& v, const ReduceProxy& rp, bool skip_self = true) + { + auto log = get_logger(); + + 
int k_in = rp.in_link().size(); + + log->trace("dequeue_values(): gid={}, round={}; v.size()={}", rp.gid(), rp.round(), v.size()); + + if (detail::is_default< Serialization >::value) + { + // add up sizes + size_t sz = 0; + size_t end = v.size(); + for (int i = 0; i < k_in; ++i) + { + log->trace(" incoming size from {}: {}", rp.in_link().target(i).gid, sz); + if (skip_self && rp.in_link().target(i).gid == rp.gid()) continue; + MemoryBuffer& in = rp.incoming(rp.in_link().target(i).gid); + sz += in.size() / sizeof(T); + } + log->trace(" incoming size: {}", sz); + v.resize(end + sz); + + for (int i = 0; i < k_in; ++i) + { + if (skip_self && rp.in_link().target(i).gid == rp.gid()) continue; + MemoryBuffer& in = rp.incoming(rp.in_link().target(i).gid); + size_t sz = in.size() / sizeof(T); + T* bg = (T*) &in.buffer[0]; + std::copy(bg, bg + sz, &v[end]); + end += sz; + } + } else + { + for (int i = 0; i < k_in; ++i) + { + if (skip_self && rp.in_link().target(i).gid == rp.gid()) continue; + MemoryBuffer& in = rp.incoming(rp.in_link().target(i).gid); + while(in) + { + T x; + diy::load(in, x); + v.emplace_back(std::move(x)); + } + } + } + log->trace(" v.size()={}", v.size()); + } + + ValuesVector values; + ValuesVector samples; + Cmp cmp; + size_t num_samples; +}; + +template +struct SampleSort::Sampler +{ + Sampler(ValuesVector values_, ValuesVector dividers_, const Cmp& cmp_, size_t num_samples_): + values(values_), dividers(dividers_), cmp(cmp_), num_samples(num_samples_) {} + + void operator()(Block* b, const ReduceProxy& srp, const RegularSwapPartners& partners) const + { + int k_in = srp.in_link().size(); + int k_out = srp.out_link().size(); + + std::vector samples; + + if (k_in == 0) + { + // draw random samples + for (size_t i = 0; i < num_samples; ++i) + samples.push_back((b->*values)[std::rand() % (b->*values).size()]); + } else + dequeue_values(samples, srp, false); + + if (k_out == 0) + { + // pick subsamples that separate quantiles + std::sort(samples.begin(), 
samples.end(), cmp); + std::vector subsamples(srp.nblocks() - 1); + int step = samples.size() / srp.nblocks(); // NB: subsamples.size() + 1 + for (size_t i = 0; i < subsamples.size(); ++i) + subsamples[i] = samples[(i+1)*step]; + (b->*dividers).swap(subsamples); + } + else + { + for (int i = 0; i < k_out; ++i) + { + MemoryBuffer& out = srp.outgoing(srp.out_link().target(i)); + save(out, &samples[0], samples.size()); + } + } + } + + ValuesVector values; + ValuesVector dividers; + Cmp cmp; + size_t num_samples; +}; + +template +struct SampleSort::Exchanger +{ + Exchanger(ValuesVector values_, ValuesVector samples_, const Cmp& cmp_): + values(values_), samples(samples_), cmp(cmp_) {} + + void operator()(Block* b, const ReduceProxy& rp) const + { + if (rp.round() == 0) + { + // enqueue values to the correct locations + for (size_t i = 0; i < (b->*values).size(); ++i) + { + int to = std::lower_bound((b->*samples).begin(), (b->*samples).end(), (b->*values)[i], cmp) - (b->*samples).begin(); + rp.enqueue(rp.out_link().target(to), (b->*values)[i]); + } + (b->*values).clear(); + } else + { + dequeue_values((b->*values), rp, false); + std::sort((b->*values).begin(), (b->*values).end(), cmp); + } + } + + ValuesVector values; + ValuesVector samples; + Cmp cmp; +}; + +} + +} + +#endif diff --git a/diy/include/diy/detail/block_traits.hpp b/diy/include/diy/detail/block_traits.hpp new file mode 100644 index 000000000..eb4b7c547 --- /dev/null +++ b/diy/include/diy/detail/block_traits.hpp @@ -0,0 +1,31 @@ +#ifndef DIY_BLOCK_TRAITS_HPP +#define DIY_BLOCK_TRAITS_HPP + +#include "traits.hpp" + +namespace diy +{ +namespace detail +{ + template + struct block_traits + { + typedef typename std::remove_pointer::template arg<0>::type>::type type; + }; + + // matches block member functions + template + struct block_traits + { + typedef Block type; + }; + + template + struct block_traits + { + typedef Block type; + }; +} +} + +#endif diff --git a/diy/include/diy/detail/collectives.hpp 
b/diy/include/diy/detail/collectives.hpp new file mode 100644 index 000000000..a85a0f3e4 --- /dev/null +++ b/diy/include/diy/detail/collectives.hpp @@ -0,0 +1,54 @@ +#ifndef DIY_COLLECTIVES_HPP +#define DIY_COLLECTIVES_HPP + +namespace diy +{ +namespace detail +{ + struct CollectiveOp + { + virtual void init() =0; + virtual void update(const CollectiveOp& other) =0; + virtual void global(const mpi::communicator& comm) =0; + virtual void copy_from(const CollectiveOp& other) =0; + virtual void result_out(void* dest) const =0; + virtual ~CollectiveOp() {} + }; + + template + struct AllReduceOp: public CollectiveOp + { + AllReduceOp(const T& x, Op op): + in_(x), op_(op) {} + + void init() { out_ = in_; } + void update(const CollectiveOp& other) { out_ = op_(out_, static_cast(other).in_); } + void global(const mpi::communicator& comm) { T res; mpi::all_reduce(comm, out_, res, op_); out_ = res; } + void copy_from(const CollectiveOp& other) { out_ = static_cast(other).out_; } + void result_out(void* dest) const { *reinterpret_cast(dest) = out_; } + + private: + T in_, out_; + Op op_; + }; + + template + struct Scratch: public CollectiveOp + { + Scratch(const T& x): + x_(x) {} + + void init() {} + void update(const CollectiveOp& other) {} + void global(const mpi::communicator& comm) {} + void copy_from(const CollectiveOp& other) {} + void result_out(void* dest) const { *reinterpret_cast(dest) = x_; } + + private: + T x_; + }; + +} +} + +#endif diff --git a/diy/include/diy/detail/reduce/all-to-all.hpp b/diy/include/diy/detail/reduce/all-to-all.hpp new file mode 100644 index 000000000..1e555db82 --- /dev/null +++ b/diy/include/diy/detail/reduce/all-to-all.hpp @@ -0,0 +1,169 @@ +#ifndef DIY_DETAIL_ALL_TO_ALL_HPP +#define DIY_DETAIL_ALL_TO_ALL_HPP + +#include "../block_traits.hpp" + +namespace diy +{ + +namespace detail +{ + template + struct AllToAllReduce + { + using Block = typename block_traits::type; + + AllToAllReduce(const Op& op_, const Assigner& assigner): + op(op_) + 
{ + for (int gid = 0; gid < assigner.nblocks(); ++gid) + { + BlockID nbr = { gid, assigner.rank(gid) }; + all_neighbors_link.add_neighbor(nbr); + } + } + + void operator()(Block* b, const ReduceProxy& srp, const RegularSwapPartners& partners) const + { + int k_in = srp.in_link().size(); + int k_out = srp.out_link().size(); + + if (k_in == 0 && k_out == 0) // special case of a single block + { + ReduceProxy all_srp_out(srp, srp.block(), 0, srp.assigner(), empty_link, all_neighbors_link); + ReduceProxy all_srp_in (srp, srp.block(), 1, srp.assigner(), all_neighbors_link, empty_link); + + op(b, all_srp_out); + MemoryBuffer& in_queue = all_srp_in.incoming(all_srp_in.in_link().target(0).gid); + in_queue.swap(all_srp_out.outgoing(all_srp_out.out_link().target(0))); + in_queue.reset(); + + op(b, all_srp_in); + return; + } + + if (k_in == 0) // initial round + { + ReduceProxy all_srp(srp, srp.block(), 0, srp.assigner(), empty_link, all_neighbors_link); + op(b, all_srp); + + Master::OutgoingQueues all_queues; + all_queues.swap(*all_srp.outgoing()); // clears out the queues and stores them locally + + // enqueue outgoing + int group = all_srp.out_link().size() / k_out; + for (int i = 0; i < k_out; ++i) + { + std::pair range(i*group, (i+1)*group); + srp.enqueue(srp.out_link().target(i), range); + for (int j = i*group; j < (i+1)*group; ++j) + { + int from = srp.gid(); + int to = all_srp.out_link().target(j).gid; + srp.enqueue(srp.out_link().target(i), std::make_pair(from, to)); + srp.enqueue(srp.out_link().target(i), all_queues[all_srp.out_link().target(j)]); + } + } + } else if (k_out == 0) // final round + { + // dequeue incoming + reorder into the correct order + ReduceProxy all_srp(srp, srp.block(), 1, srp.assigner(), all_neighbors_link, empty_link); + + Master::IncomingQueues all_incoming; + all_incoming.swap(*srp.incoming()); + + std::pair range; // all the ranges should be the same + for (int i = 0; i < k_in; ++i) + { + int gid_in = srp.in_link().target(i).gid; + 
MemoryBuffer& in = all_incoming[gid_in]; + load(in, range); + while(in) + { + std::pair from_to; + load(in, from_to); + load(in, all_srp.incoming(from_to.first)); + all_srp.incoming(from_to.first).reset(); + } + } + + op(b, all_srp); + } else // intermediate round: reshuffle queues + { + // add up buffer sizes + std::vector sizes_out(k_out, sizeof(std::pair)); + std::pair range; // all the ranges should be the same + for (int i = 0; i < k_in; ++i) + { + MemoryBuffer& in = srp.incoming(srp.in_link().target(i).gid); + + load(in, range); + int group = (range.second - range.first)/k_out; + + std::pair from_to; + size_t s; + while(in) + { + diy::load(in, from_to); + diy::load(in, s); + + int j = (from_to.second - range.first) / group; + sizes_out[j] += s + sizeof(size_t) + sizeof(std::pair); + in.skip(s); + } + in.reset(); + } + + // reserve outgoing buffers of correct size + int group = (range.second - range.first)/k_out; + for (int i = 0; i < k_out; ++i) + { + MemoryBuffer& out = srp.outgoing(srp.out_link().target(i)); + out.reserve(sizes_out[i]); + + std::pair out_range; + out_range.first = range.first + group*i; + out_range.second = range.first + group*(i+1); + save(out, out_range); + } + + // re-direct the queues + for (int i = 0; i < k_in; ++i) + { + MemoryBuffer& in = srp.incoming(srp.in_link().target(i).gid); + + std::pair range; + load(in, range); + + std::pair from_to; + while(in) + { + load(in, from_to); + int j = (from_to.second - range.first) / group; + + MemoryBuffer& out = srp.outgoing(srp.out_link().target(j)); + save(out, from_to); + MemoryBuffer::copy(in, out); + } + } + } + } + + const Op& op; + Link all_neighbors_link, empty_link; + }; + + struct SkipIntermediate + { + SkipIntermediate(size_t rounds_): + rounds(rounds_) {} + + bool operator()(int round, int, const Master&) const { if (round == 0 || round == (int) rounds) return false; return true; } + + size_t rounds; + }; +} + +} + +#endif diff --git a/diy/include/diy/detail/traits.hpp 
b/diy/include/diy/detail/traits.hpp new file mode 100644 index 000000000..f47b733c8 --- /dev/null +++ b/diy/include/diy/detail/traits.hpp @@ -0,0 +1,318 @@ +//-------------------------------------- +// utils/traits: Additional type traits +//-------------------------------------- +// +// Copyright kennytm (auraHT Ltd.) 2011. +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file doc/LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +/** + +```` --- Additional type traits +================================================= + +This module provides additional type traits and related functions, missing from +the standard library. + +*/ + +#ifndef DIY_UTILS_TRAITS_HPP +#define DIY_UTILS_TRAITS_HPP + +#include +#include +#include +#include + +namespace diy +{ +namespace detail { + +/** +.. macro:: DECLARE_HAS_TYPE_MEMBER(member_name) + + This macro declares a template ``has_member_name`` which will check whether + a type member ``member_name`` exists in a particular type. + + Example:: + + DECLARE_HAS_TYPE_MEMBER(result_type) + + ... + + printf("%d\n", has_result_type< std::plus >::value); + // ^ prints '1' (true) + printf("%d\n", has_result_type< double(*)() >::value); + // ^ prints '0' (false) +*/ +#define DECLARE_HAS_TYPE_MEMBER(member_name) \ + template \ + struct has_##member_name \ + { enum { value = false }; }; \ + template \ + struct has_##member_name::type> \ + { enum { value = true }; }; + +/** +.. type:: struct utils::function_traits + + Obtain compile-time information about a function object *F*. + + This template currently supports the following types: + + * Normal function types (``R(T...)``), function pointers (``R(*)(T...)``) + and function references (``R(&)(T...)`` and ``R(&&)(T...)``). + * Member functions (``R(C::*)(T...)``) + * ``std::function`` + * Type of lambda functions, and any other types that has a unique + ``operator()``. 
+ * Type of ``std::mem_fn`` (only for GCC's libstdc++ and LLVM's libc++). + Following the C++ spec, the first argument will be a raw pointer. +*/ +template +struct function_traits + : public function_traits +{}; + +namespace xx_impl +{ + template + struct memfn_type + { + typedef typename std::conditional< + std::is_const::value, + typename std::conditional< + std::is_volatile::value, + R (C::*)(A...) const volatile, + R (C::*)(A...) const + >::type, + typename std::conditional< + std::is_volatile::value, + R (C::*)(A...) volatile, + R (C::*)(A...) + >::type + >::type type; + }; +} + +template +struct function_traits +{ + /** + .. type:: type result_type + + The type returned by calling an instance of the function object type *F*. + */ + typedef ReturnType result_type; + + /** + .. type:: type function_type + + The function type (``R(T...)``). + */ + typedef ReturnType function_type(Args...); + + /** + .. type:: type member_function_type + + The member function type for an *OwnerType* (``R(OwnerType::*)(T...)``). + */ + template + using member_function_type = typename xx_impl::memfn_type< + typename std::remove_pointer::type>::type, + ReturnType, Args... + >::type; + + /** + .. data:: static const size_t arity + + Number of arguments the function object will take. + */ + enum { arity = sizeof...(Args) }; + + /** + .. type:: type arg::type + + The type of the *n*-th argument. 
+ */ + template + struct arg + { + typedef typename std::tuple_element>::type type; + }; +}; + +template +struct function_traits + : public function_traits +{}; + +template +struct function_traits + : public function_traits +{ + typedef ClassType& owner_type; +}; + +template +struct function_traits + : public function_traits +{ + typedef const ClassType& owner_type; +}; + +template +struct function_traits + : public function_traits +{ + typedef volatile ClassType& owner_type; +}; + +template +struct function_traits + : public function_traits +{ + typedef const volatile ClassType& owner_type; +}; + +template +struct function_traits> + : public function_traits +{}; + +#if defined(_GLIBCXX_FUNCTIONAL) +#define MEM_FN_SYMBOL_XX0SL7G4Z0J std::_Mem_fn +#elif defined(_LIBCPP_FUNCTIONAL) +#define MEM_FN_SYMBOL_XX0SL7G4Z0J std::__mem_fn +#endif + +#ifdef MEM_FN_SYMBOL_XX0SL7G4Z0J + +template +struct function_traits> + : public function_traits +{}; +template +struct function_traits> + : public function_traits +{}; +template +struct function_traits> + : public function_traits +{}; +template +struct function_traits> + : public function_traits +{}; +template +struct function_traits> + : public function_traits +{}; + +#undef MEM_FN_SYMBOL_XX0SL7G4Z0J +#endif + +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; +template +struct function_traits : public function_traits {}; + + +#define FORWARD_RES_8QR485JMSBT \ + typename std::conditional< \ + std::is_lvalue_reference::value, \ + T&, \ + typename std::remove_reference::type&& \ + >::type + +/** +.. 
function:: auto utils::forward_like(T&& t) noexcept + + Forward the reference *t* like the type of *Like*. That means, if *Like* is + an lvalue (reference), this function will return an lvalue reference of *t*. + Otherwise, if *Like* is an rvalue, this function will return an rvalue + reference of *t*. + + This is mainly used to propagate the expression category (lvalue/rvalue) of + a member of *Like*, generalizing ``std::forward``. +*/ +template +FORWARD_RES_8QR485JMSBT forward_like(T&& input) noexcept +{ + return static_cast(input); +} + +#undef FORWARD_RES_8QR485JMSBT + +/** +.. type:: struct utils::copy_cv + + Copy the CV qualifier between the two types. For example, + ``utils::copy_cv::type`` will become ``const double``. +*/ +template +struct copy_cv +{ +private: + typedef typename std::remove_cv::type raw_To; + typedef typename std::conditional::value, + const raw_To, raw_To>::type const_raw_To; +public: + /** + .. type:: type type + + Result of cv-copying. + */ + typedef typename std::conditional::value, + volatile const_raw_To, const_raw_To>::type type; +}; + +/** +.. type:: struct utils::pointee + + Returns the type by derefering an instance of *T*. This is a generalization + of ``std::remove_pointer``, that it also works with iterators. +*/ +template +struct pointee +{ + /** + .. type:: type type + + Result of dereferencing. + */ + typedef typename std::remove_reference())>::type type; +}; + +/** +.. function:: std::add_rvalue_reference::type utils::rt_val() noexcept + + Returns a value of type *T*. It is guaranteed to do nothing and will not + throw a compile-time error, but using the returned result will cause + undefined behavior. 
+*/ +template +typename std::add_rvalue_reference::type rt_val() noexcept +{ + return std::move(*static_cast(nullptr)); +} + +} + +} + +#endif + diff --git a/diy/include/diy/fmt/format.cc b/diy/include/diy/fmt/format.cc new file mode 100644 index 000000000..ae5d11034 --- /dev/null +++ b/diy/include/diy/fmt/format.cc @@ -0,0 +1,935 @@ +/* + Formatting library for C++ + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ + +#include "format.h" + +#include + +#include +#include +#include +#include +#include +#include // for std::ptrdiff_t + +#if defined(_WIN32) && defined(__MINGW32__) +# include +#endif + +#if FMT_USE_WINDOWS_H +# if defined(NOMINMAX) || defined(FMT_WIN_MINMAX) +# include +# else +# define NOMINMAX +# include +# undef NOMINMAX +# endif +#endif + +using fmt::internal::Arg; + +#if FMT_EXCEPTIONS +# define FMT_TRY try +# define FMT_CATCH(x) catch (x) +#else +# define FMT_TRY if (true) +# define FMT_CATCH(x) if (false) +#endif + +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable: 4127) // conditional expression is constant +# pragma warning(disable: 4702) // unreachable code +// Disable deprecation warning for strerror. The latter is not called but +// MSVC fails to detect it. +# pragma warning(disable: 4996) +#endif + +// Dummy implementations of strerror_r and strerror_s called if corresponding +// system functions are not available. +static inline fmt::internal::Null<> strerror_r(int, char *, ...) { + return fmt::internal::Null<>(); +} +static inline fmt::internal::Null<> strerror_s(char *, std::size_t, ...) { + return fmt::internal::Null<>(); +} + +namespace fmt { +namespace { + +#ifndef _MSC_VER +# define FMT_SNPRINTF snprintf +#else // _MSC_VER +inline int fmt_snprintf(char *buffer, size_t size, const char *format, ...) { + va_list args; + va_start(args, format); + int result = vsnprintf_s(buffer, size, _TRUNCATE, format, args); + va_end(args); + return result; +} +# define FMT_SNPRINTF fmt_snprintf +#endif // _MSC_VER + +#if defined(_WIN32) && defined(__MINGW32__) && !defined(__NO_ISOCEXT) +# define FMT_SWPRINTF snwprintf +#else +# define FMT_SWPRINTF swprintf +#endif // defined(_WIN32) && defined(__MINGW32__) && !defined(__NO_ISOCEXT) + +// Checks if a value fits in int - used to avoid warnings about comparing +// signed and unsigned integers. 
+template +struct IntChecker { + template + static bool fits_in_int(T value) { + unsigned max = INT_MAX; + return value <= max; + } + static bool fits_in_int(bool) { return true; } +}; + +template <> +struct IntChecker { + template + static bool fits_in_int(T value) { + return value >= INT_MIN && value <= INT_MAX; + } + static bool fits_in_int(int) { return true; } +}; + +const char RESET_COLOR[] = "\x1b[0m"; + +typedef void (*FormatFunc)(Writer &, int, StringRef); + +// Portable thread-safe version of strerror. +// Sets buffer to point to a string describing the error code. +// This can be either a pointer to a string stored in buffer, +// or a pointer to some static immutable string. +// Returns one of the following values: +// 0 - success +// ERANGE - buffer is not large enough to store the error message +// other - failure +// Buffer should be at least of size 1. +int safe_strerror( + int error_code, char *&buffer, std::size_t buffer_size) FMT_NOEXCEPT { + FMT_ASSERT(buffer != 0 && buffer_size != 0, "invalid buffer"); + + class StrError { + private: + int error_code_; + char *&buffer_; + std::size_t buffer_size_; + + // A noop assignment operator to avoid bogus warnings. + void operator=(const StrError &) {} + + // Handle the result of XSI-compliant version of strerror_r. + int handle(int result) { + // glibc versions before 2.13 return result in errno. + return result == -1 ? errno : result; + } + + // Handle the result of GNU-specific version of strerror_r. + int handle(char *message) { + // If the buffer is full then the message is probably truncated. + if (message == buffer_ && strlen(buffer_) == buffer_size_ - 1) + return ERANGE; + buffer_ = message; + return 0; + } + + // Handle the case when strerror_r is not available. + int handle(internal::Null<>) { + return fallback(strerror_s(buffer_, buffer_size_, error_code_)); + } + + // Fallback to strerror_s when strerror_r is not available. 
+ int fallback(int result) { + // If the buffer is full then the message is probably truncated. + return result == 0 && strlen(buffer_) == buffer_size_ - 1 ? + ERANGE : result; + } + + // Fallback to strerror if strerror_r and strerror_s are not available. + int fallback(internal::Null<>) { + errno = 0; + buffer_ = strerror(error_code_); + return errno; + } + + public: + StrError(int err_code, char *&buf, std::size_t buf_size) + : error_code_(err_code), buffer_(buf), buffer_size_(buf_size) {} + + int run() { + strerror_r(0, 0, ""); // Suppress a warning about unused strerror_r. + return handle(strerror_r(error_code_, buffer_, buffer_size_)); + } + }; + return StrError(error_code, buffer, buffer_size).run(); +} + +void format_error_code(Writer &out, int error_code, + StringRef message) FMT_NOEXCEPT { + // Report error code making sure that the output fits into + // INLINE_BUFFER_SIZE to avoid dynamic memory allocation and potential + // bad_alloc. + out.clear(); + static const char SEP[] = ": "; + static const char ERROR_STR[] = "error "; + // Subtract 2 to account for terminating null characters in SEP and ERROR_STR. + std::size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2; + typedef internal::IntTraits::MainType MainType; + MainType abs_value = static_cast(error_code); + if (internal::is_negative(error_code)) { + abs_value = 0 - abs_value; + ++error_code_size; + } + error_code_size += internal::count_digits(abs_value); + if (message.size() <= internal::INLINE_BUFFER_SIZE - error_code_size) + out << message << SEP; + out << ERROR_STR << error_code; + assert(out.size() <= internal::INLINE_BUFFER_SIZE); +} + +void report_error(FormatFunc func, int error_code, + StringRef message) FMT_NOEXCEPT { + MemoryWriter full_message; + func(full_message, error_code, message); + // Use Writer::data instead of Writer::c_str to avoid potential memory + // allocation. 
+ std::fwrite(full_message.data(), full_message.size(), 1, stderr); + std::fputc('\n', stderr); +} + +// IsZeroInt::visit(arg) returns true iff arg is a zero integer. +class IsZeroInt : public ArgVisitor { + public: + template + bool visit_any_int(T value) { return value == 0; } +}; + +// Checks if an argument is a valid printf width specifier and sets +// left alignment if it is negative. +class WidthHandler : public ArgVisitor { + private: + FormatSpec &spec_; + + FMT_DISALLOW_COPY_AND_ASSIGN(WidthHandler); + + public: + explicit WidthHandler(FormatSpec &spec) : spec_(spec) {} + + void report_unhandled_arg() { + FMT_THROW(FormatError("width is not integer")); + } + + template + unsigned visit_any_int(T value) { + typedef typename internal::IntTraits::MainType UnsignedType; + UnsignedType width = static_cast(value); + if (internal::is_negative(value)) { + spec_.align_ = ALIGN_LEFT; + width = 0 - width; + } + if (width > INT_MAX) + FMT_THROW(FormatError("number is too big")); + return static_cast(width); + } +}; + +class PrecisionHandler : public ArgVisitor { + public: + void report_unhandled_arg() { + FMT_THROW(FormatError("precision is not integer")); + } + + template + int visit_any_int(T value) { + if (!IntChecker::is_signed>::fits_in_int(value)) + FMT_THROW(FormatError("number is too big")); + return static_cast(value); + } +}; + +template +struct is_same { + enum { value = 0 }; +}; + +template +struct is_same { + enum { value = 1 }; +}; + +// An argument visitor that converts an integer argument to T for printf, +// if T is an integral type. 
If T is void, the argument is converted to +// corresponding signed or unsigned type depending on the type specifier: +// 'd' and 'i' - signed, other - unsigned) +template +class ArgConverter : public ArgVisitor, void> { + private: + internal::Arg &arg_; + wchar_t type_; + + FMT_DISALLOW_COPY_AND_ASSIGN(ArgConverter); + + public: + ArgConverter(internal::Arg &arg, wchar_t type) + : arg_(arg), type_(type) {} + + void visit_bool(bool value) { + if (type_ != 's') + visit_any_int(value); + } + + template + void visit_any_int(U value) { + bool is_signed = type_ == 'd' || type_ == 'i'; + using internal::Arg; + typedef typename internal::Conditional< + is_same::value, U, T>::type TargetType; + if (sizeof(TargetType) <= sizeof(int)) { + // Extra casts are used to silence warnings. + if (is_signed) { + arg_.type = Arg::INT; + arg_.int_value = static_cast(static_cast(value)); + } else { + arg_.type = Arg::UINT; + typedef typename internal::MakeUnsigned::Type Unsigned; + arg_.uint_value = static_cast(static_cast(value)); + } + } else { + if (is_signed) { + arg_.type = Arg::LONG_LONG; + // glibc's printf doesn't sign extend arguments of smaller types: + // std::printf("%lld", -42); // prints "4294967254" + // but we don't have to do the same because it's a UB. + arg_.long_long_value = static_cast(value); + } else { + arg_.type = Arg::ULONG_LONG; + arg_.ulong_long_value = + static_cast::Type>(value); + } + } + } +}; + +// Converts an integer argument to char for printf. 
+class CharConverter : public ArgVisitor { + private: + internal::Arg &arg_; + + FMT_DISALLOW_COPY_AND_ASSIGN(CharConverter); + + public: + explicit CharConverter(internal::Arg &arg) : arg_(arg) {} + + template + void visit_any_int(T value) { + arg_.type = internal::Arg::CHAR; + arg_.int_value = static_cast(value); + } +}; +} // namespace + +namespace internal { + +template +class PrintfArgFormatter : + public ArgFormatterBase, Char> { + + void write_null_pointer() { + this->spec().type_ = 0; + this->write("(nil)"); + } + + typedef ArgFormatterBase, Char> Base; + + public: + PrintfArgFormatter(BasicWriter &w, FormatSpec &s) + : ArgFormatterBase, Char>(w, s) {} + + void visit_bool(bool value) { + FormatSpec &fmt_spec = this->spec(); + if (fmt_spec.type_ != 's') + return this->visit_any_int(value); + fmt_spec.type_ = 0; + this->write(value); + } + + void visit_char(int value) { + const FormatSpec &fmt_spec = this->spec(); + BasicWriter &w = this->writer(); + if (fmt_spec.type_ && fmt_spec.type_ != 'c') + w.write_int(value, fmt_spec); + typedef typename BasicWriter::CharPtr CharPtr; + CharPtr out = CharPtr(); + if (fmt_spec.width_ > 1) { + Char fill = ' '; + out = w.grow_buffer(fmt_spec.width_); + if (fmt_spec.align_ != ALIGN_LEFT) { + std::fill_n(out, fmt_spec.width_ - 1, fill); + out += fmt_spec.width_ - 1; + } else { + std::fill_n(out + 1, fmt_spec.width_ - 1, fill); + } + } else { + out = w.grow_buffer(1); + } + *out = static_cast(value); + } + + void visit_cstring(const char *value) { + if (value) + Base::visit_cstring(value); + else if (this->spec().type_ == 'p') + write_null_pointer(); + else + this->write("(null)"); + } + + void visit_pointer(const void *value) { + if (value) + return Base::visit_pointer(value); + this->spec().type_ = 0; + write_null_pointer(); + } + + void visit_custom(Arg::CustomValue c) { + BasicFormatter formatter(ArgList(), this->writer()); + const Char format_str[] = {'}', 0}; + const Char *format = format_str; + c.format(&formatter, 
c.value, &format); + } +}; +} // namespace internal +} // namespace fmt + +FMT_FUNC void fmt::SystemError::init( + int err_code, CStringRef format_str, ArgList args) { + error_code_ = err_code; + MemoryWriter w; + internal::format_system_error(w, err_code, format(format_str, args)); + std::runtime_error &base = *this; + base = std::runtime_error(w.str()); +} + +template +int fmt::internal::CharTraits::format_float( + char *buffer, std::size_t size, const char *format, + unsigned width, int precision, T value) { + if (width == 0) { + return precision < 0 ? + FMT_SNPRINTF(buffer, size, format, value) : + FMT_SNPRINTF(buffer, size, format, precision, value); + } + return precision < 0 ? + FMT_SNPRINTF(buffer, size, format, width, value) : + FMT_SNPRINTF(buffer, size, format, width, precision, value); +} + +template +int fmt::internal::CharTraits::format_float( + wchar_t *buffer, std::size_t size, const wchar_t *format, + unsigned width, int precision, T value) { + if (width == 0) { + return precision < 0 ? + FMT_SWPRINTF(buffer, size, format, value) : + FMT_SWPRINTF(buffer, size, format, precision, value); + } + return precision < 0 ? 
+ FMT_SWPRINTF(buffer, size, format, width, value) : + FMT_SWPRINTF(buffer, size, format, width, precision, value); +} + +template +const char fmt::internal::BasicData::DIGITS[] = + "0001020304050607080910111213141516171819" + "2021222324252627282930313233343536373839" + "4041424344454647484950515253545556575859" + "6061626364656667686970717273747576777879" + "8081828384858687888990919293949596979899"; + +#define FMT_POWERS_OF_10(factor) \ + factor * 10, \ + factor * 100, \ + factor * 1000, \ + factor * 10000, \ + factor * 100000, \ + factor * 1000000, \ + factor * 10000000, \ + factor * 100000000, \ + factor * 1000000000 + +template +const uint32_t fmt::internal::BasicData::POWERS_OF_10_32[] = { + 0, FMT_POWERS_OF_10(1) +}; + +template +const uint64_t fmt::internal::BasicData::POWERS_OF_10_64[] = { + 0, + FMT_POWERS_OF_10(1), + FMT_POWERS_OF_10(fmt::ULongLong(1000000000)), + // Multiply several constants instead of using a single long long constant + // to avoid warnings about C++98 not supporting long long. 
+ fmt::ULongLong(1000000000) * fmt::ULongLong(1000000000) * 10 +}; + +FMT_FUNC void fmt::internal::report_unknown_type(char code, const char *type) { + (void)type; + if (std::isprint(static_cast(code))) { + FMT_THROW(fmt::FormatError( + fmt::format("unknown format code '{}' for {}", code, type))); + } + FMT_THROW(fmt::FormatError( + fmt::format("unknown format code '\\x{:02x}' for {}", + static_cast(code), type))); +} + +#if FMT_USE_WINDOWS_H + +FMT_FUNC fmt::internal::UTF8ToUTF16::UTF8ToUTF16(fmt::StringRef s) { + static const char ERROR_MSG[] = "cannot convert string from UTF-8 to UTF-16"; + if (s.size() > INT_MAX) + FMT_THROW(WindowsError(ERROR_INVALID_PARAMETER, ERROR_MSG)); + int s_size = static_cast(s.size()); + int length = MultiByteToWideChar( + CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), s_size, 0, 0); + if (length == 0) + FMT_THROW(WindowsError(GetLastError(), ERROR_MSG)); + buffer_.resize(length + 1); + length = MultiByteToWideChar( + CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), s_size, &buffer_[0], length); + if (length == 0) + FMT_THROW(WindowsError(GetLastError(), ERROR_MSG)); + buffer_[length] = 0; +} + +FMT_FUNC fmt::internal::UTF16ToUTF8::UTF16ToUTF8(fmt::WStringRef s) { + if (int error_code = convert(s)) { + FMT_THROW(WindowsError(error_code, + "cannot convert string from UTF-16 to UTF-8")); + } +} + +FMT_FUNC int fmt::internal::UTF16ToUTF8::convert(fmt::WStringRef s) { + if (s.size() > INT_MAX) + return ERROR_INVALID_PARAMETER; + int s_size = static_cast(s.size()); + int length = WideCharToMultiByte(CP_UTF8, 0, s.data(), s_size, 0, 0, 0, 0); + if (length == 0) + return GetLastError(); + buffer_.resize(length + 1); + length = WideCharToMultiByte( + CP_UTF8, 0, s.data(), s_size, &buffer_[0], length, 0, 0); + if (length == 0) + return GetLastError(); + buffer_[length] = 0; + return 0; +} + +FMT_FUNC void fmt::WindowsError::init( + int err_code, CStringRef format_str, ArgList args) { + error_code_ = err_code; + MemoryWriter w; + 
internal::format_windows_error(w, err_code, format(format_str, args)); + std::runtime_error &base = *this; + base = std::runtime_error(w.str()); +} + +FMT_FUNC void fmt::internal::format_windows_error( + fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT { + FMT_TRY { + MemoryBuffer buffer; + buffer.resize(INLINE_BUFFER_SIZE); + for (;;) { + wchar_t *system_message = &buffer[0]; + int result = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + 0, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + system_message, static_cast(buffer.size()), 0); + if (result != 0) { + UTF16ToUTF8 utf8_message; + if (utf8_message.convert(system_message) == ERROR_SUCCESS) { + out << message << ": " << utf8_message; + return; + } + break; + } + if (GetLastError() != ERROR_INSUFFICIENT_BUFFER) + break; // Can't get error message, report error code instead. + buffer.resize(buffer.size() * 2); + } + } FMT_CATCH(...) {} + fmt::format_error_code(out, error_code, message); // 'fmt::' is for bcc32. +} + +#endif // FMT_USE_WINDOWS_H + +FMT_FUNC void fmt::internal::format_system_error( + fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT { + FMT_TRY { + MemoryBuffer buffer; + buffer.resize(INLINE_BUFFER_SIZE); + for (;;) { + char *system_message = &buffer[0]; + int result = safe_strerror(error_code, system_message, buffer.size()); + if (result == 0) { + out << message << ": " << system_message; + return; + } + if (result != ERANGE) + break; // Can't get error message, report error code instead. + buffer.resize(buffer.size() * 2); + } + } FMT_CATCH(...) {} + fmt::format_error_code(out, error_code, message); // 'fmt::' is for bcc32. 
+} + +template +void fmt::internal::ArgMap::init(const ArgList &args) { + if (!map_.empty()) + return; + typedef internal::NamedArg NamedArg; + const NamedArg *named_arg = 0; + bool use_values = + args.type(ArgList::MAX_PACKED_ARGS - 1) == internal::Arg::NONE; + if (use_values) { + for (unsigned i = 0;/*nothing*/; ++i) { + internal::Arg::Type arg_type = args.type(i); + switch (arg_type) { + case internal::Arg::NONE: + return; + case internal::Arg::NAMED_ARG: + named_arg = static_cast(args.values_[i].pointer); + map_.push_back(Pair(named_arg->name, *named_arg)); + break; + default: + /*nothing*/; + } + } + return; + } + for (unsigned i = 0; i != ArgList::MAX_PACKED_ARGS; ++i) { + internal::Arg::Type arg_type = args.type(i); + if (arg_type == internal::Arg::NAMED_ARG) { + named_arg = static_cast(args.args_[i].pointer); + map_.push_back(Pair(named_arg->name, *named_arg)); + } + } + for (unsigned i = ArgList::MAX_PACKED_ARGS;/*nothing*/; ++i) { + switch (args.args_[i].type) { + case internal::Arg::NONE: + return; + case internal::Arg::NAMED_ARG: + named_arg = static_cast(args.args_[i].pointer); + map_.push_back(Pair(named_arg->name, *named_arg)); + break; + default: + /*nothing*/; + } + } +} + +template +void fmt::internal::FixedBuffer::grow(std::size_t) { + FMT_THROW(std::runtime_error("buffer overflow")); +} + +FMT_FUNC Arg fmt::internal::FormatterBase::do_get_arg( + unsigned arg_index, const char *&error) { + Arg arg = args_[arg_index]; + switch (arg.type) { + case Arg::NONE: + error = "argument index out of range"; + break; + case Arg::NAMED_ARG: + arg = *static_cast(arg.pointer); + break; + default: + /*nothing*/; + } + return arg; +} + +template +void fmt::internal::PrintfFormatter::parse_flags( + FormatSpec &spec, const Char *&s) { + for (;;) { + switch (*s++) { + case '-': + spec.align_ = ALIGN_LEFT; + break; + case '+': + spec.flags_ |= SIGN_FLAG | PLUS_FLAG; + break; + case '0': + spec.fill_ = '0'; + break; + case ' ': + spec.flags_ |= SIGN_FLAG; + break; + 
case '#': + spec.flags_ |= HASH_FLAG; + break; + default: + --s; + return; + } + } +} + +template +Arg fmt::internal::PrintfFormatter::get_arg( + const Char *s, unsigned arg_index) { + (void)s; + const char *error = 0; + Arg arg = arg_index == UINT_MAX ? + next_arg(error) : FormatterBase::get_arg(arg_index - 1, error); + if (error) + FMT_THROW(FormatError(!*s ? "invalid format string" : error)); + return arg; +} + +template +unsigned fmt::internal::PrintfFormatter::parse_header( + const Char *&s, FormatSpec &spec) { + unsigned arg_index = UINT_MAX; + Char c = *s; + if (c >= '0' && c <= '9') { + // Parse an argument index (if followed by '$') or a width possibly + // preceded with '0' flag(s). + unsigned value = parse_nonnegative_int(s); + if (*s == '$') { // value is an argument index + ++s; + arg_index = value; + } else { + if (c == '0') + spec.fill_ = '0'; + if (value != 0) { + // Nonzero value means that we parsed width and don't need to + // parse it or flags again, so return now. + spec.width_ = value; + return arg_index; + } + } + } + parse_flags(spec, s); + // Parse width. + if (*s >= '0' && *s <= '9') { + spec.width_ = parse_nonnegative_int(s); + } else if (*s == '*') { + ++s; + spec.width_ = WidthHandler(spec).visit(get_arg(s)); + } + return arg_index; +} + +template +void fmt::internal::PrintfFormatter::format( + BasicWriter &writer, BasicCStringRef format_str) { + const Char *start = format_str.c_str(); + const Char *s = start; + while (*s) { + Char c = *s++; + if (c != '%') continue; + if (*s == c) { + write(writer, start, s); + start = ++s; + continue; + } + write(writer, start, s - 1); + + FormatSpec spec; + spec.align_ = ALIGN_RIGHT; + + // Parse argument index, flags and width. + unsigned arg_index = parse_header(s, spec); + + // Parse precision. 
+ if (*s == '.') { + ++s; + if ('0' <= *s && *s <= '9') { + spec.precision_ = static_cast(parse_nonnegative_int(s)); + } else if (*s == '*') { + ++s; + spec.precision_ = PrecisionHandler().visit(get_arg(s)); + } + } + + Arg arg = get_arg(s, arg_index); + if (spec.flag(HASH_FLAG) && IsZeroInt().visit(arg)) + spec.flags_ &= ~to_unsigned(HASH_FLAG); + if (spec.fill_ == '0') { + if (arg.type <= Arg::LAST_NUMERIC_TYPE) + spec.align_ = ALIGN_NUMERIC; + else + spec.fill_ = ' '; // Ignore '0' flag for non-numeric types. + } + + // Parse length and convert the argument to the required type. + switch (*s++) { + case 'h': + if (*s == 'h') + ArgConverter(arg, *++s).visit(arg); + else + ArgConverter(arg, *s).visit(arg); + break; + case 'l': + if (*s == 'l') + ArgConverter(arg, *++s).visit(arg); + else + ArgConverter(arg, *s).visit(arg); + break; + case 'j': + ArgConverter(arg, *s).visit(arg); + break; + case 'z': + ArgConverter(arg, *s).visit(arg); + break; + case 't': + ArgConverter(arg, *s).visit(arg); + break; + case 'L': + // printf produces garbage when 'L' is omitted for long double, no + // need to do the same. + break; + default: + --s; + ArgConverter(arg, *s).visit(arg); + } + + // Parse type. + if (!*s) + FMT_THROW(FormatError("invalid format string")); + spec.type_ = static_cast(*s++); + if (arg.type <= Arg::LAST_INTEGER_TYPE) { + // Normalize type. + switch (spec.type_) { + case 'i': case 'u': + spec.type_ = 'd'; + break; + case 'c': + // TODO: handle wchar_t + CharConverter(arg).visit(arg); + break; + } + } + + start = s; + + // Format argument. + internal::PrintfArgFormatter(writer, spec).visit(arg); + } + write(writer, start, s); +} + +FMT_FUNC void fmt::report_system_error( + int error_code, fmt::StringRef message) FMT_NOEXCEPT { + // 'fmt::' is for bcc32. 
+ fmt::report_error(internal::format_system_error, error_code, message); +} + +#if FMT_USE_WINDOWS_H +FMT_FUNC void fmt::report_windows_error( + int error_code, fmt::StringRef message) FMT_NOEXCEPT { + // 'fmt::' is for bcc32. + fmt::report_error(internal::format_windows_error, error_code, message); +} +#endif + +FMT_FUNC void fmt::print(std::FILE *f, CStringRef format_str, ArgList args) { + MemoryWriter w; + w.write(format_str, args); + std::fwrite(w.data(), 1, w.size(), f); +} + +FMT_FUNC void fmt::print(CStringRef format_str, ArgList args) { + print(stdout, format_str, args); +} + +FMT_FUNC void fmt::print_colored(Color c, CStringRef format, ArgList args) { + char escape[] = "\x1b[30m"; + escape[3] = static_cast('0' + c); + std::fputs(escape, stdout); + print(format, args); + std::fputs(RESET_COLOR, stdout); +} + +FMT_FUNC int fmt::fprintf(std::FILE *f, CStringRef format, ArgList args) { + MemoryWriter w; + printf(w, format, args); + std::size_t size = w.size(); + return std::fwrite(w.data(), 1, size, f) < size ? -1 : static_cast(size); +} + +#ifndef FMT_HEADER_ONLY + +template struct fmt::internal::BasicData; + +// Explicit instantiations for char. + +template void fmt::internal::FixedBuffer::grow(std::size_t); + +template void fmt::internal::ArgMap::init(const fmt::ArgList &args); + +template void fmt::internal::PrintfFormatter::format( + BasicWriter &writer, CStringRef format); + +template int fmt::internal::CharTraits::format_float( + char *buffer, std::size_t size, const char *format, + unsigned width, int precision, double value); + +template int fmt::internal::CharTraits::format_float( + char *buffer, std::size_t size, const char *format, + unsigned width, int precision, long double value); + +// Explicit instantiations for wchar_t. 
+ +template void fmt::internal::FixedBuffer::grow(std::size_t); + +template void fmt::internal::ArgMap::init(const fmt::ArgList &args); + +template void fmt::internal::PrintfFormatter::format( + BasicWriter &writer, WCStringRef format); + +template int fmt::internal::CharTraits::format_float( + wchar_t *buffer, std::size_t size, const wchar_t *format, + unsigned width, int precision, double value); + +template int fmt::internal::CharTraits::format_float( + wchar_t *buffer, std::size_t size, const wchar_t *format, + unsigned width, int precision, long double value); + +#endif // FMT_HEADER_ONLY + +#ifdef _MSC_VER +# pragma warning(pop) +#endif diff --git a/diy/include/diy/fmt/format.h b/diy/include/diy/fmt/format.h new file mode 100644 index 000000000..0ca1576b8 --- /dev/null +++ b/diy/include/diy/fmt/format.h @@ -0,0 +1,3834 @@ +/* + Formatting library for C++ + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef FMT_FORMAT_H_ +#define FMT_FORMAT_H_ + +#define FMT_HEADER_ONLY // Added by diy for header-only usage + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _SECURE_SCL +# define FMT_SECURE_SCL _SECURE_SCL +#else +# define FMT_SECURE_SCL 0 +#endif + +#if FMT_SECURE_SCL +# include +#endif + +#if defined(_MSC_VER) && _MSC_VER <= 1500 +typedef unsigned __int32 uint32_t; +typedef unsigned __int64 uint64_t; +typedef __int64 intmax_t; +#else +#include +#endif + +#if !defined(FMT_HEADER_ONLY) && defined(_WIN32) +# ifdef FMT_EXPORT +# define FMT_API __declspec(dllexport) +# elif defined(FMT_SHARED) +# define FMT_API __declspec(dllimport) +# endif +#endif +#ifndef FMT_API +# define FMT_API +#endif + +#ifdef __GNUC__ +# define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +# define FMT_GCC_EXTENSION __extension__ +# if FMT_GCC_VERSION >= 406 +# pragma GCC diagnostic push +// Disable the warning about "long long" which is sometimes reported even +// when using __extension__. +# pragma GCC diagnostic ignored "-Wlong-long" +// Disable the warning about declaration shadowing because it affects too +// many valid cases. +# pragma GCC diagnostic ignored "-Wshadow" +// Disable the warning about implicit conversions that may change the sign of +// an integer; silencing it otherwise would require many explicit casts. 
+# pragma GCC diagnostic ignored "-Wsign-conversion" +# endif +# if __cplusplus >= 201103L || defined __GXX_EXPERIMENTAL_CXX0X__ +# define FMT_HAS_GXX_CXX11 1 +# endif +#else +# define FMT_GCC_EXTENSION +#endif + +#if defined(__INTEL_COMPILER) +# define FMT_ICC_VERSION __INTEL_COMPILER +#elif defined(__ICL) +# define FMT_ICC_VERSION __ICL +#endif + +#if defined(__clang__) && !defined(FMT_ICC_VERSION) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wdocumentation" +#endif + +#ifdef __GNUC_LIBSTD__ +# define FMT_GNUC_LIBSTD_VERSION (__GNUC_LIBSTD__ * 100 + __GNUC_LIBSTD_MINOR__) +#endif + +#ifdef __has_feature +# define FMT_HAS_FEATURE(x) __has_feature(x) +#else +# define FMT_HAS_FEATURE(x) 0 +#endif + +#ifdef __has_builtin +# define FMT_HAS_BUILTIN(x) __has_builtin(x) +#else +# define FMT_HAS_BUILTIN(x) 0 +#endif + +#ifdef __has_cpp_attribute +# define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +# define FMT_HAS_CPP_ATTRIBUTE(x) 0 +#endif + +#ifndef FMT_USE_VARIADIC_TEMPLATES +// Variadic templates are available in GCC since version 4.4 +// (http://gcc.gnu.org/projects/cxx0x.html) and in Visual C++ +// since version 2013. +# define FMT_USE_VARIADIC_TEMPLATES \ + (FMT_HAS_FEATURE(cxx_variadic_templates) || \ + (FMT_GCC_VERSION >= 404 && FMT_HAS_GXX_CXX11) || _MSC_VER >= 1800) +#endif + +#ifndef FMT_USE_RVALUE_REFERENCES +// Don't use rvalue references when compiling with clang and an old libstdc++ +// as the latter doesn't provide std::move. +# if defined(FMT_GNUC_LIBSTD_VERSION) && FMT_GNUC_LIBSTD_VERSION <= 402 +# define FMT_USE_RVALUE_REFERENCES 0 +# else +# define FMT_USE_RVALUE_REFERENCES \ + (FMT_HAS_FEATURE(cxx_rvalue_references) || \ + (FMT_GCC_VERSION >= 403 && FMT_HAS_GXX_CXX11) || _MSC_VER >= 1600) +# endif +#endif + +#if FMT_USE_RVALUE_REFERENCES +# include // for std::move +#endif + +// Check if exceptions are disabled. 
+#if defined(__GNUC__) && !defined(__EXCEPTIONS) +# define FMT_EXCEPTIONS 0 +#endif +#if defined(_MSC_VER) && !_HAS_EXCEPTIONS +# define FMT_EXCEPTIONS 0 +#endif +#ifndef FMT_EXCEPTIONS +# define FMT_EXCEPTIONS 1 +#endif + +#ifndef FMT_THROW +# if FMT_EXCEPTIONS +# define FMT_THROW(x) throw x +# else +# define FMT_THROW(x) assert(false) +# endif +#endif + +// Define FMT_USE_NOEXCEPT to make fmt use noexcept (C++11 feature). +#ifndef FMT_USE_NOEXCEPT +# define FMT_USE_NOEXCEPT 0 +#endif + +#ifndef FMT_NOEXCEPT +# if FMT_EXCEPTIONS +# if FMT_USE_NOEXCEPT || FMT_HAS_FEATURE(cxx_noexcept) || \ + (FMT_GCC_VERSION >= 408 && FMT_HAS_GXX_CXX11) || \ + _MSC_VER >= 1900 +# define FMT_NOEXCEPT noexcept +# else +# define FMT_NOEXCEPT throw() +# endif +# else +# define FMT_NOEXCEPT +# endif +#endif + +// A macro to disallow the copy constructor and operator= functions +// This should be used in the private: declarations for a class +#ifndef FMT_USE_DELETED_FUNCTIONS +# define FMT_USE_DELETED_FUNCTIONS 0 +#endif + +#if FMT_USE_DELETED_FUNCTIONS || FMT_HAS_FEATURE(cxx_deleted_functions) || \ + (FMT_GCC_VERSION >= 404 && FMT_HAS_GXX_CXX11) || _MSC_VER >= 1800 +# define FMT_DELETED_OR_UNDEFINED = delete +# define FMT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + TypeName& operator=(const TypeName&) = delete +#else +# define FMT_DELETED_OR_UNDEFINED +# define FMT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&); \ + TypeName& operator=(const TypeName&) +#endif + +#ifndef FMT_USE_USER_DEFINED_LITERALS +// All compilers which support UDLs also support variadic templates. This +// makes the fmt::literals implementation easier. However, an explicit check +// for variadic templates is added here just in case. +// For Intel's compiler both it and the system gcc/msc must support UDLs. 
+# define FMT_USE_USER_DEFINED_LITERALS \ + FMT_USE_VARIADIC_TEMPLATES && FMT_USE_RVALUE_REFERENCES && \ + (FMT_HAS_FEATURE(cxx_user_literals) || \ + (FMT_GCC_VERSION >= 407 && FMT_HAS_GXX_CXX11) || _MSC_VER >= 1900) && \ + (!defined(FMT_ICC_VERSION) || FMT_ICC_VERSION >= 1500) +#endif + +#ifndef FMT_ASSERT +# define FMT_ASSERT(condition, message) assert((condition) && message) +#endif + + +#if FMT_GCC_VERSION >= 400 || FMT_HAS_BUILTIN(__builtin_clz) +# define FMT_BUILTIN_CLZ(n) __builtin_clz(n) +#endif + +#if FMT_GCC_VERSION >= 400 || FMT_HAS_BUILTIN(__builtin_clzll) +# define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) +#endif + +// Some compilers masquerade as both MSVC and GCC-likes or +// otherwise support __builtin_clz and __builtin_clzll, so +// only define FMT_BUILTIN_CLZ using the MSVC intrinsics +// if the clz and clzll builtins are not available. +#if defined(_MSC_VER) && !defined(FMT_BUILTIN_CLZLL) +# include // _BitScanReverse, _BitScanReverse64 + +namespace fmt { +namespace internal { +# pragma intrinsic(_BitScanReverse) +inline uint32_t clz(uint32_t x) { + unsigned long r = 0; + _BitScanReverse(&r, x); + + assert(x != 0); + // Static analysis complains about using uninitialized data + // "r", but the only way that can happen is if "x" is 0, + // which the callers guarantee to not happen. +# pragma warning(suppress: 6102) + return 31 - r; +} +# define FMT_BUILTIN_CLZ(n) fmt::internal::clz(n) + +# ifdef _WIN64 +# pragma intrinsic(_BitScanReverse64) +# endif + +inline uint32_t clzll(uint64_t x) { + unsigned long r = 0; +# ifdef _WIN64 + _BitScanReverse64(&r, x); +# else + // Scan the high 32 bits. + if (_BitScanReverse(&r, static_cast(x >> 32))) + return 63 - (r + 32); + + // Scan the low 32 bits. + _BitScanReverse(&r, static_cast(x)); +# endif + + assert(x != 0); + // Static analysis complains about using uninitialized data + // "r", but the only way that can happen is if "x" is 0, + // which the callers guarantee to not happen. 
+# pragma warning(suppress: 6102) + return 63 - r; +} +# define FMT_BUILTIN_CLZLL(n) fmt::internal::clzll(n) +} +} +#endif + +namespace fmt { +namespace internal { +struct DummyInt { + int data[2]; + operator int() const { return 0; } +}; +typedef std::numeric_limits FPUtil; + +// Dummy implementations of system functions such as signbit and ecvt called +// if the latter are not available. +inline DummyInt signbit(...) { return DummyInt(); } +inline DummyInt _ecvt_s(...) { return DummyInt(); } +inline DummyInt isinf(...) { return DummyInt(); } +inline DummyInt _finite(...) { return DummyInt(); } +inline DummyInt isnan(...) { return DummyInt(); } +inline DummyInt _isnan(...) { return DummyInt(); } + +// A helper function to suppress bogus "conditional expression is constant" +// warnings. +template +inline T check(T value) { return value; } +} +} // namespace fmt + +namespace std { +// Standard permits specialization of std::numeric_limits. This specialization +// is used to resolve ambiguity between isinf and std::isinf in glibc: +// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=48891 +// and the same for isnan and signbit. +template <> +class numeric_limits : + public std::numeric_limits { + public: + // Portable version of isinf. + template + static bool isinfinity(T x) { + using namespace fmt::internal; + // The resolution "priority" is: + // isinf macro > std::isinf > ::isinf > fmt::internal::isinf + if (check(sizeof(isinf(x)) == sizeof(bool) || + sizeof(isinf(x)) == sizeof(int))) { + return isinf(x) != 0; + } + return !_finite(static_cast(x)); + } + + // Portable version of isnan. + template + static bool isnotanumber(T x) { + using namespace fmt::internal; + if (check(sizeof(isnan(x)) == sizeof(bool) || + sizeof(isnan(x)) == sizeof(int))) { + return isnan(x) != 0; + } + return _isnan(static_cast(x)) != 0; + } + + // Portable version of signbit. 
+ static bool isnegative(double x) { + using namespace fmt::internal; + if (check(sizeof(signbit(x)) == sizeof(int))) + return signbit(x) != 0; + if (x < 0) return true; + if (!isnotanumber(x)) return false; + int dec = 0, sign = 0; + char buffer[2]; // The buffer size must be >= 2 or _ecvt_s will fail. + _ecvt_s(buffer, sizeof(buffer), x, 0, &dec, &sign); + return sign != 0; + } +}; +} // namespace std + +namespace fmt { + +// Fix the warning about long long on older versions of GCC +// that don't support the diagnostic pragma. +FMT_GCC_EXTENSION typedef long long LongLong; +FMT_GCC_EXTENSION typedef unsigned long long ULongLong; + +#if FMT_USE_RVALUE_REFERENCES +using std::move; +#endif + +template +class BasicWriter; + +typedef BasicWriter Writer; +typedef BasicWriter WWriter; + +template +class ArgFormatter; + +template > +class BasicFormatter; + +/** + \rst + A string reference. It can be constructed from a C string or ``std::string``. + + You can use one of the following typedefs for common character types: + + +------------+-------------------------+ + | Type | Definition | + +============+=========================+ + | StringRef | BasicStringRef | + +------------+-------------------------+ + | WStringRef | BasicStringRef | + +------------+-------------------------+ + + This class is most useful as a parameter type to allow passing + different types of strings to a function, for example:: + + template + std::string format(StringRef format_str, const Args & ... args); + + format("{}", 42); + format(std::string("{}"), 42); + \endrst + */ +template +class BasicStringRef { + private: + const Char *data_; + std::size_t size_; + + public: + /** Constructs a string reference object from a C string and a size. */ + BasicStringRef(const Char *s, std::size_t size) : data_(s), size_(size) {} + + /** + \rst + Constructs a string reference object from a C string computing + the size with ``std::char_traits::length``. 
+ \endrst + */ + BasicStringRef(const Char *s) + : data_(s), size_(std::char_traits::length(s)) {} + + /** + \rst + Constructs a string reference from an ``std::string`` object. + \endrst + */ + BasicStringRef(const std::basic_string &s) + : data_(s.c_str()), size_(s.size()) {} + + /** + \rst + Converts a string reference to an ``std::string`` object. + \endrst + */ + std::basic_string to_string() const { + return std::basic_string(data_, size_); + } + + /** Returns a pointer to the string data. */ + const Char *data() const { return data_; } + + /** Returns the string size. */ + std::size_t size() const { return size_; } + + // Lexicographically compare this string reference to other. + int compare(BasicStringRef other) const { + std::size_t size = size_ < other.size_ ? size_ : other.size_; + int result = std::char_traits::compare(data_, other.data_, size); + if (result == 0) + result = size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); + return result; + } + + friend bool operator==(BasicStringRef lhs, BasicStringRef rhs) { + return lhs.compare(rhs) == 0; + } + friend bool operator!=(BasicStringRef lhs, BasicStringRef rhs) { + return lhs.compare(rhs) != 0; + } + friend bool operator<(BasicStringRef lhs, BasicStringRef rhs) { + return lhs.compare(rhs) < 0; + } + friend bool operator<=(BasicStringRef lhs, BasicStringRef rhs) { + return lhs.compare(rhs) <= 0; + } + friend bool operator>(BasicStringRef lhs, BasicStringRef rhs) { + return lhs.compare(rhs) > 0; + } + friend bool operator>=(BasicStringRef lhs, BasicStringRef rhs) { + return lhs.compare(rhs) >= 0; + } +}; + +typedef BasicStringRef StringRef; +typedef BasicStringRef WStringRef; + +/** + \rst + A reference to a null terminated string. It can be constructed from a C + string or ``std::string``. 
+ + You can use one of the following typedefs for common character types: + + +-------------+--------------------------+ + | Type | Definition | + +=============+==========================+ + | CStringRef | BasicCStringRef | + +-------------+--------------------------+ + | WCStringRef | BasicCStringRef | + +-------------+--------------------------+ + + This class is most useful as a parameter type to allow passing + different types of strings to a function, for example:: + + template + std::string format(CStringRef format_str, const Args & ... args); + + format("{}", 42); + format(std::string("{}"), 42); + \endrst + */ +template +class BasicCStringRef { + private: + const Char *data_; + + public: + /** Constructs a string reference object from a C string. */ + BasicCStringRef(const Char *s) : data_(s) {} + + /** + \rst + Constructs a string reference from an ``std::string`` object. + \endrst + */ + BasicCStringRef(const std::basic_string &s) : data_(s.c_str()) {} + + /** Returns the pointer to a C string. */ + const Char *c_str() const { return data_; } +}; + +typedef BasicCStringRef CStringRef; +typedef BasicCStringRef WCStringRef; + +/** + A formatting error such as invalid format string. +*/ +class FormatError : public std::runtime_error { + public: + explicit FormatError(CStringRef message) + : std::runtime_error(message.c_str()) {} +}; + +namespace internal { + +// MakeUnsigned::Type gives an unsigned type corresponding to integer type T. +template +struct MakeUnsigned { typedef T Type; }; + +#define FMT_SPECIALIZE_MAKE_UNSIGNED(T, U) \ + template <> \ + struct MakeUnsigned { typedef U Type; } + +FMT_SPECIALIZE_MAKE_UNSIGNED(char, unsigned char); +FMT_SPECIALIZE_MAKE_UNSIGNED(signed char, unsigned char); +FMT_SPECIALIZE_MAKE_UNSIGNED(short, unsigned short); +FMT_SPECIALIZE_MAKE_UNSIGNED(int, unsigned); +FMT_SPECIALIZE_MAKE_UNSIGNED(long, unsigned long); +FMT_SPECIALIZE_MAKE_UNSIGNED(LongLong, ULongLong); + +// Casts nonnegative integer to unsigned. 
+template +inline typename MakeUnsigned::Type to_unsigned(Int value) { + FMT_ASSERT(value >= 0, "negative value"); + return static_cast::Type>(value); +} + +// The number of characters to store in the MemoryBuffer object itself +// to avoid dynamic memory allocation. +enum { INLINE_BUFFER_SIZE = 500 }; + +#if FMT_SECURE_SCL +// Use checked iterator to avoid warnings on MSVC. +template +inline stdext::checked_array_iterator make_ptr(T *ptr, std::size_t size) { + return stdext::checked_array_iterator(ptr, size); +} +#else +template +inline T *make_ptr(T *ptr, std::size_t) { return ptr; } +#endif +} // namespace internal + +/** + \rst + A buffer supporting a subset of ``std::vector``'s operations. + \endrst + */ +template +class Buffer { + private: + FMT_DISALLOW_COPY_AND_ASSIGN(Buffer); + + protected: + T *ptr_; + std::size_t size_; + std::size_t capacity_; + + Buffer(T *ptr = 0, std::size_t capacity = 0) + : ptr_(ptr), size_(0), capacity_(capacity) {} + + /** + \rst + Increases the buffer capacity to hold at least *size* elements updating + ``ptr_`` and ``capacity_``. + \endrst + */ + virtual void grow(std::size_t size) = 0; + + public: + virtual ~Buffer() {} + + /** Returns the size of this buffer. */ + std::size_t size() const { return size_; } + + /** Returns the capacity of this buffer. */ + std::size_t capacity() const { return capacity_; } + + /** + Resizes the buffer. If T is a POD type new elements may not be initialized. + */ + void resize(std::size_t new_size) { + if (new_size > capacity_) + grow(new_size); + size_ = new_size; + } + + /** + \rst + Reserves space to store at least *capacity* elements. + \endrst + */ + void reserve(std::size_t capacity) { + if (capacity > capacity_) + grow(capacity); + } + + void clear() FMT_NOEXCEPT { size_ = 0; } + + void push_back(const T &value) { + if (size_ == capacity_) + grow(size_ + 1); + ptr_[size_++] = value; + } + + /** Appends data to the end of the buffer. 
*/ + template + void append(const U *begin, const U *end); + + T &operator[](std::size_t index) { return ptr_[index]; } + const T &operator[](std::size_t index) const { return ptr_[index]; } +}; + +template +template +void Buffer::append(const U *begin, const U *end) { + std::size_t new_size = size_ + internal::to_unsigned(end - begin); + if (new_size > capacity_) + grow(new_size); + std::uninitialized_copy(begin, end, + internal::make_ptr(ptr_, capacity_) + size_); + size_ = new_size; +} + +namespace internal { + +// A memory buffer for trivially copyable/constructible types with the first SIZE +// elements stored in the object itself. +template > +class MemoryBuffer : private Allocator, public Buffer { + private: + T data_[SIZE]; + + // Deallocate memory allocated by the buffer. + void deallocate() { + if (this->ptr_ != data_) Allocator::deallocate(this->ptr_, this->capacity_); + } + + protected: + void grow(std::size_t size); + + public: + explicit MemoryBuffer(const Allocator &alloc = Allocator()) + : Allocator(alloc), Buffer(data_, SIZE) {} + ~MemoryBuffer() { deallocate(); } + +#if FMT_USE_RVALUE_REFERENCES + private: + // Move data from other to this buffer. + void move(MemoryBuffer &other) { + Allocator &this_alloc = *this, &other_alloc = other; + this_alloc = std::move(other_alloc); + this->size_ = other.size_; + this->capacity_ = other.capacity_; + if (other.ptr_ == other.data_) { + this->ptr_ = data_; + std::uninitialized_copy(other.data_, other.data_ + this->size_, + make_ptr(data_, this->capacity_)); + } else { + this->ptr_ = other.ptr_; + // Set pointer to the inline array so that delete is not called + // when deallocating. + other.ptr_ = other.data_; + } + } + + public: + MemoryBuffer(MemoryBuffer &&other) { + move(other); + } + + MemoryBuffer &operator=(MemoryBuffer &&other) { + assert(this != &other); + deallocate(); + move(other); + return *this; + } +#endif + + // Returns a copy of the allocator associated with this buffer. 
+ Allocator get_allocator() const { return *this; } +}; + +template +void MemoryBuffer::grow(std::size_t size) { + std::size_t new_capacity = this->capacity_ + this->capacity_ / 2; + if (size > new_capacity) + new_capacity = size; + T *new_ptr = this->allocate(new_capacity); + // The following code doesn't throw, so the raw pointer above doesn't leak. + std::uninitialized_copy(this->ptr_, this->ptr_ + this->size_, + make_ptr(new_ptr, new_capacity)); + std::size_t old_capacity = this->capacity_; + T *old_ptr = this->ptr_; + this->capacity_ = new_capacity; + this->ptr_ = new_ptr; + // deallocate may throw (at least in principle), but it doesn't matter since + // the buffer already uses the new storage and will deallocate it in case + // of exception. + if (old_ptr != data_) + Allocator::deallocate(old_ptr, old_capacity); +} + +// A fixed-size buffer. +template +class FixedBuffer : public fmt::Buffer { + public: + FixedBuffer(Char *array, std::size_t size) : fmt::Buffer(array, size) {} + + protected: + FMT_API void grow(std::size_t size); +}; + +template +class BasicCharTraits { + public: +#if FMT_SECURE_SCL + typedef stdext::checked_array_iterator CharPtr; +#else + typedef Char *CharPtr; +#endif + static Char cast(int value) { return static_cast(value); } +}; + +template +class CharTraits; + +template <> +class CharTraits : public BasicCharTraits { + private: + // Conversion from wchar_t to char is not allowed. + static char convert(wchar_t); + + public: + static char convert(char value) { return value; } + + // Formats a floating-point number. 
+ template + FMT_API static int format_float(char *buffer, std::size_t size, + const char *format, unsigned width, int precision, T value); +}; + +template <> +class CharTraits : public BasicCharTraits { + public: + static wchar_t convert(char value) { return value; } + static wchar_t convert(wchar_t value) { return value; } + + template + FMT_API static int format_float(wchar_t *buffer, std::size_t size, + const wchar_t *format, unsigned width, int precision, T value); +}; + +// Checks if a number is negative - used to avoid warnings. +template +struct SignChecker { + template + static bool is_negative(T value) { return value < 0; } +}; + +template <> +struct SignChecker { + template + static bool is_negative(T) { return false; } +}; + +// Returns true if value is negative, false otherwise. +// Same as (value < 0) but doesn't produce warnings if T is an unsigned type. +template +inline bool is_negative(T value) { + return SignChecker::is_signed>::is_negative(value); +} + +// Selects uint32_t if FitsIn32Bits is true, uint64_t otherwise. +template +struct TypeSelector { typedef uint32_t Type; }; + +template <> +struct TypeSelector { typedef uint64_t Type; }; + +template +struct IntTraits { + // Smallest of uint32_t and uint64_t that is large enough to represent + // all values of T. + typedef typename + TypeSelector::digits <= 32>::Type MainType; +}; + +FMT_API void report_unknown_type(char code, const char *type); + +// Static data is placed in this class template to allow header-only +// configuration. +template +struct FMT_API BasicData { + static const uint32_t POWERS_OF_10_32[]; + static const uint64_t POWERS_OF_10_64[]; + static const char DIGITS[]; +}; + +typedef BasicData<> Data; + +#ifdef FMT_BUILTIN_CLZLL +// Returns the number of decimal digits in n. Leading zeros are not counted +// except for n == 0 in which case count_digits returns 1. 
+inline unsigned count_digits(uint64_t n) { + // Based on http://graphics.stanford.edu/~seander/bithacks.html#IntegerLog10 + // and the benchmark https://github.com/localvoid/cxx-benchmark-count-digits. + int t = (64 - FMT_BUILTIN_CLZLL(n | 1)) * 1233 >> 12; + return to_unsigned(t) - (n < Data::POWERS_OF_10_64[t]) + 1; +} +#else +// Fallback version of count_digits used when __builtin_clz is not available. +inline unsigned count_digits(uint64_t n) { + unsigned count = 1; + for (;;) { + // Integer division is slow so do it for a group of four digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + if (n < 10) return count; + if (n < 100) return count + 1; + if (n < 1000) return count + 2; + if (n < 10000) return count + 3; + n /= 10000u; + count += 4; + } +} +#endif + +#ifdef FMT_BUILTIN_CLZ +// Optional version of count_digits for better performance on 32-bit platforms. +inline unsigned count_digits(uint32_t n) { + int t = (32 - FMT_BUILTIN_CLZ(n | 1)) * 1233 >> 12; + return to_unsigned(t) - (n < Data::POWERS_OF_10_32[t]) + 1; +} +#endif + +// A functor that doesn't add a thousands separator. +struct NoThousandsSep { + template + void operator()(Char *) {} +}; + +// A functor that adds a thousands separator. +class ThousandsSep { + private: + fmt::StringRef sep_; + + // Index of a decimal digit with the least significant digit having index 0. + unsigned digit_index_; + + public: + explicit ThousandsSep(fmt::StringRef sep) : sep_(sep), digit_index_(0) {} + + template + void operator()(Char *&buffer) { + if (++digit_index_ % 3 != 0) + return; + buffer -= sep_.size(); + std::uninitialized_copy(sep_.data(), sep_.data() + sep_.size(), + internal::make_ptr(buffer, sep_.size())); + } +}; + +// Formats a decimal unsigned integer value writing into buffer. +// thousands_sep is a functor that is called after writing each char to +// add a thousands separator if necessary. 
+template +inline void format_decimal(Char *buffer, UInt value, unsigned num_digits, + ThousandsSep thousands_sep) { + buffer += num_digits; + while (value >= 100) { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + unsigned index = static_cast((value % 100) * 2); + value /= 100; + *--buffer = Data::DIGITS[index + 1]; + thousands_sep(buffer); + *--buffer = Data::DIGITS[index]; + thousands_sep(buffer); + } + if (value < 10) { + *--buffer = static_cast('0' + value); + return; + } + unsigned index = static_cast(value * 2); + *--buffer = Data::DIGITS[index + 1]; + *--buffer = Data::DIGITS[index]; +} + +template +inline void format_decimal(Char *buffer, UInt value, unsigned num_digits) { + return format_decimal(buffer, value, num_digits, NoThousandsSep()); +} + +#ifndef _WIN32 +# define FMT_USE_WINDOWS_H 0 +#elif !defined(FMT_USE_WINDOWS_H) +# define FMT_USE_WINDOWS_H 1 +#endif + +// Define FMT_USE_WINDOWS_H to 0 to disable use of windows.h. +// All the functionality that relies on it will be disabled too. +#if FMT_USE_WINDOWS_H +// A converter from UTF-8 to UTF-16. +// It is only provided for Windows since other systems support UTF-8 natively. +class UTF8ToUTF16 { + private: + MemoryBuffer buffer_; + + public: + FMT_API explicit UTF8ToUTF16(StringRef s); + operator WStringRef() const { return WStringRef(&buffer_[0], size()); } + size_t size() const { return buffer_.size() - 1; } + const wchar_t *c_str() const { return &buffer_[0]; } + std::wstring str() const { return std::wstring(&buffer_[0], size()); } +}; + +// A converter from UTF-16 to UTF-8. +// It is only provided for Windows since other systems support UTF-8 natively. 
+class UTF16ToUTF8 { + private: + MemoryBuffer buffer_; + + public: + UTF16ToUTF8() {} + FMT_API explicit UTF16ToUTF8(WStringRef s); + operator StringRef() const { return StringRef(&buffer_[0], size()); } + size_t size() const { return buffer_.size() - 1; } + const char *c_str() const { return &buffer_[0]; } + std::string str() const { return std::string(&buffer_[0], size()); } + + // Performs conversion returning a system error code instead of + // throwing exception on conversion error. This method may still throw + // in case of memory allocation error. + FMT_API int convert(WStringRef s); +}; + +FMT_API void format_windows_error(fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT; +#endif + +FMT_API void format_system_error(fmt::Writer &out, int error_code, + fmt::StringRef message) FMT_NOEXCEPT; + +// A formatting argument value. +struct Value { + template + struct StringValue { + const Char *value; + std::size_t size; + }; + + typedef void (*FormatFunc)( + void *formatter, const void *arg, void *format_str_ptr); + + struct CustomValue { + const void *value; + FormatFunc format; + }; + + union { + int int_value; + unsigned uint_value; + LongLong long_long_value; + ULongLong ulong_long_value; + double double_value; + long double long_double_value; + const void *pointer; + StringValue string; + StringValue sstring; + StringValue ustring; + StringValue wstring; + CustomValue custom; + }; + + enum Type { + NONE, NAMED_ARG, + // Integer types should go first, + INT, UINT, LONG_LONG, ULONG_LONG, BOOL, CHAR, LAST_INTEGER_TYPE = CHAR, + // followed by floating-point types. + DOUBLE, LONG_DOUBLE, LAST_NUMERIC_TYPE = LONG_DOUBLE, + CSTRING, STRING, WSTRING, POINTER, CUSTOM + }; +}; + +// A formatting argument. It is a trivially copyable/constructible type to +// allow storage in internal::MemoryBuffer. 
+struct Arg : Value { + Type type; +}; + +template +struct NamedArg; + +template +struct Null {}; + +// A helper class template to enable or disable overloads taking wide +// characters and strings in MakeValue. +template +struct WCharHelper { + typedef Null Supported; + typedef T Unsupported; +}; + +template +struct WCharHelper { + typedef T Supported; + typedef Null Unsupported; +}; + +typedef char Yes[1]; +typedef char No[2]; + +template +T &get(); + +// These are non-members to workaround an overload resolution bug in bcc32. +Yes &convert(fmt::ULongLong); +No &convert(...); + +template +struct ConvertToIntImpl { + enum { value = ENABLE_CONVERSION }; +}; + +template +struct ConvertToIntImpl2 { + enum { value = false }; +}; + +template +struct ConvertToIntImpl2 { + enum { + // Don't convert numeric types. + value = ConvertToIntImpl::is_specialized>::value + }; +}; + +template +struct ConvertToInt { + enum { enable_conversion = sizeof(convert(get())) == sizeof(Yes) }; + enum { value = ConvertToIntImpl2::value }; +}; + +#define FMT_DISABLE_CONVERSION_TO_INT(Type) \ + template <> \ + struct ConvertToInt { enum { value = 0 }; } + +// Silence warnings about convering float to int. +FMT_DISABLE_CONVERSION_TO_INT(float); +FMT_DISABLE_CONVERSION_TO_INT(double); +FMT_DISABLE_CONVERSION_TO_INT(long double); + +template +struct EnableIf {}; + +template +struct EnableIf { typedef T type; }; + +template +struct Conditional { typedef T type; }; + +template +struct Conditional { typedef F type; }; + +// For bcc32 which doesn't understand ! in template arguments. +template +struct Not { enum { value = 0 }; }; + +template<> +struct Not { enum { value = 1 }; }; + +// Makes an Arg object from any type. +template +class MakeValue : public Arg { + public: + typedef typename Formatter::Char Char; + + private: + // The following two methods are private to disallow formatting of + // arbitrary pointers. If you want to output a pointer cast it to + // "void *" or "const void *". 
In particular, this forbids formatting + // of "[const] volatile char *" which is printed as bool by iostreams. + // Do not implement! + template + MakeValue(const T *value); + template + MakeValue(T *value); + + // The following methods are private to disallow formatting of wide + // characters and strings into narrow strings as in + // fmt::format("{}", L"test"); + // To fix this, use a wide format string: fmt::format(L"{}", L"test"). +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) + MakeValue(typename WCharHelper::Unsupported); +#endif + MakeValue(typename WCharHelper::Unsupported); + MakeValue(typename WCharHelper::Unsupported); + MakeValue(typename WCharHelper::Unsupported); + MakeValue(typename WCharHelper::Unsupported); + + void set_string(StringRef str) { + string.value = str.data(); + string.size = str.size(); + } + + void set_string(WStringRef str) { + wstring.value = str.data(); + wstring.size = str.size(); + } + + // Formats an argument of a custom type, such as a user-defined class. + template + static void format_custom_arg( + void *formatter, const void *arg, void *format_str_ptr) { + format(*static_cast(formatter), + *static_cast(format_str_ptr), + *static_cast(arg)); + } + + public: + MakeValue() {} + +#define FMT_MAKE_VALUE_(Type, field, TYPE, rhs) \ + MakeValue(Type value) { field = rhs; } \ + static uint64_t type(Type) { return Arg::TYPE; } + +#define FMT_MAKE_VALUE(Type, field, TYPE) \ + FMT_MAKE_VALUE_(Type, field, TYPE, value) + + FMT_MAKE_VALUE(bool, int_value, BOOL) + FMT_MAKE_VALUE(short, int_value, INT) + FMT_MAKE_VALUE(unsigned short, uint_value, UINT) + FMT_MAKE_VALUE(int, int_value, INT) + FMT_MAKE_VALUE(unsigned, uint_value, UINT) + + MakeValue(long value) { + // To minimize the number of types we need to deal with, long is + // translated either to int or to long long depending on its size. 
+ if (check(sizeof(long) == sizeof(int))) + int_value = static_cast(value); + else + long_long_value = value; + } + static uint64_t type(long) { + return sizeof(long) == sizeof(int) ? Arg::INT : Arg::LONG_LONG; + } + + MakeValue(unsigned long value) { + if (check(sizeof(unsigned long) == sizeof(unsigned))) + uint_value = static_cast(value); + else + ulong_long_value = value; + } + static uint64_t type(unsigned long) { + return sizeof(unsigned long) == sizeof(unsigned) ? + Arg::UINT : Arg::ULONG_LONG; + } + + FMT_MAKE_VALUE(LongLong, long_long_value, LONG_LONG) + FMT_MAKE_VALUE(ULongLong, ulong_long_value, ULONG_LONG) + FMT_MAKE_VALUE(float, double_value, DOUBLE) + FMT_MAKE_VALUE(double, double_value, DOUBLE) + FMT_MAKE_VALUE(long double, long_double_value, LONG_DOUBLE) + FMT_MAKE_VALUE(signed char, int_value, INT) + FMT_MAKE_VALUE(unsigned char, uint_value, UINT) + FMT_MAKE_VALUE(char, int_value, CHAR) + +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) + MakeValue(typename WCharHelper::Supported value) { + int_value = value; + } + static uint64_t type(wchar_t) { return Arg::CHAR; } +#endif + +#define FMT_MAKE_STR_VALUE(Type, TYPE) \ + MakeValue(Type value) { set_string(value); } \ + static uint64_t type(Type) { return Arg::TYPE; } + + FMT_MAKE_VALUE(char *, string.value, CSTRING) + FMT_MAKE_VALUE(const char *, string.value, CSTRING) + FMT_MAKE_VALUE(const signed char *, sstring.value, CSTRING) + FMT_MAKE_VALUE(const unsigned char *, ustring.value, CSTRING) + FMT_MAKE_STR_VALUE(const std::string &, STRING) + FMT_MAKE_STR_VALUE(StringRef, STRING) + FMT_MAKE_VALUE_(CStringRef, string.value, CSTRING, value.c_str()) + +#define FMT_MAKE_WSTR_VALUE(Type, TYPE) \ + MakeValue(typename WCharHelper::Supported value) { \ + set_string(value); \ + } \ + static uint64_t type(Type) { return Arg::TYPE; } + + FMT_MAKE_WSTR_VALUE(wchar_t *, WSTRING) + FMT_MAKE_WSTR_VALUE(const wchar_t *, WSTRING) + FMT_MAKE_WSTR_VALUE(const std::wstring &, WSTRING) + 
FMT_MAKE_WSTR_VALUE(WStringRef, WSTRING) + + FMT_MAKE_VALUE(void *, pointer, POINTER) + FMT_MAKE_VALUE(const void *, pointer, POINTER) + + template + MakeValue(const T &value, + typename EnableIf::value>::value, int>::type = 0) { + custom.value = &value; + custom.format = &format_custom_arg; + } + + template + MakeValue(const T &value, + typename EnableIf::value, int>::type = 0) { + int_value = value; + } + + template + static uint64_t type(const T &) { + return ConvertToInt::value ? Arg::INT : Arg::CUSTOM; + } + + // Additional template param `Char_` is needed here because make_type always + // uses char. + template + MakeValue(const NamedArg &value) { pointer = &value; } + + template + static uint64_t type(const NamedArg &) { return Arg::NAMED_ARG; } +}; + +template +class MakeArg : public Arg { +public: + MakeArg() { + type = Arg::NONE; + } + + template + MakeArg(const T &value) + : Arg(MakeValue(value)) { + type = static_cast(MakeValue::type(value)); + } +}; + +template +struct NamedArg : Arg { + BasicStringRef name; + + template + NamedArg(BasicStringRef argname, const T &value) + : Arg(MakeArg< BasicFormatter >(value)), name(argname) {} +}; + +class RuntimeError : public std::runtime_error { + protected: + RuntimeError() : std::runtime_error("") {} +}; + +template +class PrintfArgFormatter; + +template +class ArgMap; +} // namespace internal + +/** An argument list. */ +class ArgList { + private: + // To reduce compiled code size per formatting function call, types of first + // MAX_PACKED_ARGS arguments are passed in the types_ field. + uint64_t types_; + union { + // If the number of arguments is less than MAX_PACKED_ARGS, the argument + // values are stored in values_, otherwise they are stored in args_. + // This is done to reduce compiled code size as storing larger objects + // may require more code (at least on x86-64) even if the same amount of + // data is actually copied to stack. It saves ~10% on the bloat test. 
+ const internal::Value *values_; + const internal::Arg *args_; + }; + + internal::Arg::Type type(unsigned index) const { + unsigned shift = index * 4; + uint64_t mask = 0xf; + return static_cast( + (types_ & (mask << shift)) >> shift); + } + + template + friend class internal::ArgMap; + + public: + // Maximum number of arguments with packed types. + enum { MAX_PACKED_ARGS = 16 }; + + ArgList() : types_(0) {} + + ArgList(ULongLong types, const internal::Value *values) + : types_(types), values_(values) {} + ArgList(ULongLong types, const internal::Arg *args) + : types_(types), args_(args) {} + + /** Returns the argument at specified index. */ + internal::Arg operator[](unsigned index) const { + using internal::Arg; + Arg arg; + bool use_values = type(MAX_PACKED_ARGS - 1) == Arg::NONE; + if (index < MAX_PACKED_ARGS) { + Arg::Type arg_type = type(index); + internal::Value &val = arg; + if (arg_type != Arg::NONE) + val = use_values ? values_[index] : args_[index]; + arg.type = arg_type; + return arg; + } + if (use_values) { + // The index is greater than the number of arguments that can be stored + // in values, so return a "none" argument. + arg.type = Arg::NONE; + return arg; + } + for (unsigned i = MAX_PACKED_ARGS; i <= index; ++i) { + if (args_[i].type == Arg::NONE) + return args_[i]; + } + return args_[index]; + } +}; + +#define FMT_DISPATCH(call) static_cast(this)->call + +/** + \rst + An argument visitor based on the `curiously recurring template pattern + `_. + + To use `~fmt::ArgVisitor` define a subclass that implements some or all of the + visit methods with the same signatures as the methods in `~fmt::ArgVisitor`, + for example, `~fmt::ArgVisitor::visit_int()`. + Pass the subclass as the *Impl* template parameter. Then calling + `~fmt::ArgVisitor::visit` for some argument will dispatch to a visit method + specific to the argument type. 
For example, if the argument type is + ``double`` then the `~fmt::ArgVisitor::visit_double()` method of a subclass + will be called. If the subclass doesn't contain a method with this signature, + then a corresponding method of `~fmt::ArgVisitor` will be called. + + **Example**:: + + class MyArgVisitor : public fmt::ArgVisitor { + public: + void visit_int(int value) { fmt::print("{}", value); } + void visit_double(double value) { fmt::print("{}", value ); } + }; + \endrst + */ +template +class ArgVisitor { + private: + typedef internal::Arg Arg; + + public: + void report_unhandled_arg() {} + + Result visit_unhandled_arg() { + FMT_DISPATCH(report_unhandled_arg()); + return Result(); + } + + /** Visits an ``int`` argument. **/ + Result visit_int(int value) { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits a ``long long`` argument. **/ + Result visit_long_long(LongLong value) { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits an ``unsigned`` argument. **/ + Result visit_uint(unsigned value) { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits an ``unsigned long long`` argument. **/ + Result visit_ulong_long(ULongLong value) { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits a ``bool`` argument. **/ + Result visit_bool(bool value) { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits a ``char`` or ``wchar_t`` argument. **/ + Result visit_char(int value) { + return FMT_DISPATCH(visit_any_int(value)); + } + + /** Visits an argument of any integral type. **/ + template + Result visit_any_int(T) { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a ``double`` argument. **/ + Result visit_double(double value) { + return FMT_DISPATCH(visit_any_double(value)); + } + + /** Visits a ``long double`` argument. **/ + Result visit_long_double(long double value) { + return FMT_DISPATCH(visit_any_double(value)); + } + + /** Visits a ``double`` or ``long double`` argument. 
**/ + template + Result visit_any_double(T) { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a null-terminated C string (``const char *``) argument. **/ + Result visit_cstring(const char *) { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a string argument. **/ + Result visit_string(Arg::StringValue) { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a wide string argument. **/ + Result visit_wstring(Arg::StringValue) { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits a pointer argument. **/ + Result visit_pointer(const void *) { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** Visits an argument of a custom (user-defined) type. **/ + Result visit_custom(Arg::CustomValue) { + return FMT_DISPATCH(visit_unhandled_arg()); + } + + /** + \rst + Visits an argument dispatching to the appropriate visit method based on + the argument type. For example, if the argument type is ``double`` then + the `~fmt::ArgVisitor::visit_double()` method of the *Impl* class will be + called. 
+ \endrst + */ + Result visit(const Arg &arg) { + switch (arg.type) { + default: + FMT_ASSERT(false, "invalid argument type"); + return Result(); + case Arg::INT: + return FMT_DISPATCH(visit_int(arg.int_value)); + case Arg::UINT: + return FMT_DISPATCH(visit_uint(arg.uint_value)); + case Arg::LONG_LONG: + return FMT_DISPATCH(visit_long_long(arg.long_long_value)); + case Arg::ULONG_LONG: + return FMT_DISPATCH(visit_ulong_long(arg.ulong_long_value)); + case Arg::BOOL: + return FMT_DISPATCH(visit_bool(arg.int_value != 0)); + case Arg::CHAR: + return FMT_DISPATCH(visit_char(arg.int_value)); + case Arg::DOUBLE: + return FMT_DISPATCH(visit_double(arg.double_value)); + case Arg::LONG_DOUBLE: + return FMT_DISPATCH(visit_long_double(arg.long_double_value)); + case Arg::CSTRING: + return FMT_DISPATCH(visit_cstring(arg.string.value)); + case Arg::STRING: + return FMT_DISPATCH(visit_string(arg.string)); + case Arg::WSTRING: + return FMT_DISPATCH(visit_wstring(arg.wstring)); + case Arg::POINTER: + return FMT_DISPATCH(visit_pointer(arg.pointer)); + case Arg::CUSTOM: + return FMT_DISPATCH(visit_custom(arg.custom)); + } + } +}; + +enum Alignment { + ALIGN_DEFAULT, ALIGN_LEFT, ALIGN_RIGHT, ALIGN_CENTER, ALIGN_NUMERIC +}; + +// Flags. +enum { + SIGN_FLAG = 1, PLUS_FLAG = 2, MINUS_FLAG = 4, HASH_FLAG = 8, + CHAR_FLAG = 0x10 // Argument has char type - used in error reporting. +}; + +// An empty format specifier. +struct EmptySpec {}; + +// A type specifier. +template +struct TypeSpec : EmptySpec { + Alignment align() const { return ALIGN_DEFAULT; } + unsigned width() const { return 0; } + int precision() const { return -1; } + bool flag(unsigned) const { return false; } + char type() const { return TYPE; } + char fill() const { return ' '; } +}; + +// A width specifier. +struct WidthSpec { + unsigned width_; + // Fill is always wchar_t and cast to char if necessary to avoid having + // two specialization of WidthSpec and its subclasses. 
+ wchar_t fill_; + + WidthSpec(unsigned width, wchar_t fill) : width_(width), fill_(fill) {} + + unsigned width() const { return width_; } + wchar_t fill() const { return fill_; } +}; + +// An alignment specifier. +struct AlignSpec : WidthSpec { + Alignment align_; + + AlignSpec(unsigned width, wchar_t fill, Alignment align = ALIGN_DEFAULT) + : WidthSpec(width, fill), align_(align) {} + + Alignment align() const { return align_; } + + int precision() const { return -1; } +}; + +// An alignment and type specifier. +template +struct AlignTypeSpec : AlignSpec { + AlignTypeSpec(unsigned width, wchar_t fill) : AlignSpec(width, fill) {} + + bool flag(unsigned) const { return false; } + char type() const { return TYPE; } +}; + +// A full format specifier. +struct FormatSpec : AlignSpec { + unsigned flags_; + int precision_; + char type_; + + FormatSpec( + unsigned width = 0, char type = 0, wchar_t fill = ' ') + : AlignSpec(width, fill), flags_(0), precision_(-1), type_(type) {} + + bool flag(unsigned f) const { return (flags_ & f) != 0; } + int precision() const { return precision_; } + char type() const { return type_; } +}; + +// An integer format specifier. +template , typename Char = char> +class IntFormatSpec : public SpecT { + private: + T value_; + + public: + IntFormatSpec(T val, const SpecT &spec = SpecT()) + : SpecT(spec), value_(val) {} + + T value() const { return value_; } +}; + +// A string format specifier. +template +class StrFormatSpec : public AlignSpec { + private: + const Char *str_; + + public: + template + StrFormatSpec(const Char *str, unsigned width, FillChar fill) + : AlignSpec(width, fill), str_(str) { + internal::CharTraits::convert(FillChar()); + } + + const Char *str() const { return str_; } +}; + +/** + Returns an integer format specifier to format the value in base 2. + */ +IntFormatSpec > bin(int value); + +/** + Returns an integer format specifier to format the value in base 8. 
+ */ +IntFormatSpec > oct(int value); + +/** + Returns an integer format specifier to format the value in base 16 using + lower-case letters for the digits above 9. + */ +IntFormatSpec > hex(int value); + +/** + Returns an integer formatter format specifier to format in base 16 using + upper-case letters for the digits above 9. + */ +IntFormatSpec > hexu(int value); + +/** + \rst + Returns an integer format specifier to pad the formatted argument with the + fill character to the specified width using the default (right) numeric + alignment. + + **Example**:: + + MemoryWriter out; + out << pad(hex(0xcafe), 8, '0'); + // out.str() == "0000cafe" + + \endrst + */ +template +IntFormatSpec, Char> pad( + int value, unsigned width, Char fill = ' '); + +#define FMT_DEFINE_INT_FORMATTERS(TYPE) \ +inline IntFormatSpec > bin(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'b'>()); \ +} \ + \ +inline IntFormatSpec > oct(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'o'>()); \ +} \ + \ +inline IntFormatSpec > hex(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'x'>()); \ +} \ + \ +inline IntFormatSpec > hexu(TYPE value) { \ + return IntFormatSpec >(value, TypeSpec<'X'>()); \ +} \ + \ +template \ +inline IntFormatSpec > pad( \ + IntFormatSpec > f, unsigned width) { \ + return IntFormatSpec >( \ + f.value(), AlignTypeSpec(width, ' ')); \ +} \ + \ +/* For compatibility with older compilers we provide two overloads for pad, */ \ +/* one that takes a fill character and one that doesn't. In the future this */ \ +/* can be replaced with one overload making the template argument Char */ \ +/* default to char (C++11). 
*/ \ +template \ +inline IntFormatSpec, Char> pad( \ + IntFormatSpec, Char> f, \ + unsigned width, Char fill) { \ + return IntFormatSpec, Char>( \ + f.value(), AlignTypeSpec(width, fill)); \ +} \ + \ +inline IntFormatSpec > pad( \ + TYPE value, unsigned width) { \ + return IntFormatSpec >( \ + value, AlignTypeSpec<0>(width, ' ')); \ +} \ + \ +template \ +inline IntFormatSpec, Char> pad( \ + TYPE value, unsigned width, Char fill) { \ + return IntFormatSpec, Char>( \ + value, AlignTypeSpec<0>(width, fill)); \ +} + +FMT_DEFINE_INT_FORMATTERS(int) +FMT_DEFINE_INT_FORMATTERS(long) +FMT_DEFINE_INT_FORMATTERS(unsigned) +FMT_DEFINE_INT_FORMATTERS(unsigned long) +FMT_DEFINE_INT_FORMATTERS(LongLong) +FMT_DEFINE_INT_FORMATTERS(ULongLong) + +/** + \rst + Returns a string formatter that pads the formatted argument with the fill + character to the specified width using the default (left) string alignment. + + **Example**:: + + std::string s = str(MemoryWriter() << pad("abc", 8)); + // s == "abc " + + \endrst + */ +template +inline StrFormatSpec pad( + const Char *str, unsigned width, Char fill = ' ') { + return StrFormatSpec(str, width, fill); +} + +inline StrFormatSpec pad( + const wchar_t *str, unsigned width, char fill = ' ') { + return StrFormatSpec(str, width, fill); +} + +namespace internal { + +template +class ArgMap { + private: + typedef std::vector< + std::pair, internal::Arg> > MapType; + typedef typename MapType::value_type Pair; + + MapType map_; + + public: + FMT_API void init(const ArgList &args); + + const internal::Arg* find(const fmt::BasicStringRef &name) const { + // The list is unsorted, so just return the first matching name. 
+ for (typename MapType::const_iterator it = map_.begin(), end = map_.end(); + it != end; ++it) { + if (it->first == name) + return &it->second; + } + return 0; + } +}; + +template +class ArgFormatterBase : public ArgVisitor { + private: + BasicWriter &writer_; + FormatSpec &spec_; + + FMT_DISALLOW_COPY_AND_ASSIGN(ArgFormatterBase); + + void write_pointer(const void *p) { + spec_.flags_ = HASH_FLAG; + spec_.type_ = 'x'; + writer_.write_int(reinterpret_cast(p), spec_); + } + + protected: + BasicWriter &writer() { return writer_; } + FormatSpec &spec() { return spec_; } + + void write(bool value) { + const char *str_value = value ? "true" : "false"; + Arg::StringValue str = { str_value, std::strlen(str_value) }; + writer_.write_str(str, spec_); + } + + void write(const char *value) { + Arg::StringValue str = {value, value != 0 ? std::strlen(value) : 0}; + writer_.write_str(str, spec_); + } + + public: + ArgFormatterBase(BasicWriter &w, FormatSpec &s) + : writer_(w), spec_(s) {} + + template + void visit_any_int(T value) { writer_.write_int(value, spec_); } + + template + void visit_any_double(T value) { writer_.write_double(value, spec_); } + + void visit_bool(bool value) { + if (spec_.type_) + return visit_any_int(value); + write(value); + } + + void visit_char(int value) { + if (spec_.type_ && spec_.type_ != 'c') { + spec_.flags_ |= CHAR_FLAG; + writer_.write_int(value, spec_); + return; + } + if (spec_.align_ == ALIGN_NUMERIC || spec_.flags_ != 0) + FMT_THROW(FormatError("invalid format specifier for char")); + typedef typename BasicWriter::CharPtr CharPtr; + Char fill = internal::CharTraits::cast(spec_.fill()); + CharPtr out = CharPtr(); + const unsigned CHAR_WIDTH = 1; + if (spec_.width_ > CHAR_WIDTH) { + out = writer_.grow_buffer(spec_.width_); + if (spec_.align_ == ALIGN_RIGHT) { + std::uninitialized_fill_n(out, spec_.width_ - CHAR_WIDTH, fill); + out += spec_.width_ - CHAR_WIDTH; + } else if (spec_.align_ == ALIGN_CENTER) { + out = writer_.fill_padding(out, 
spec_.width_, + internal::check(CHAR_WIDTH), fill); + } else { + std::uninitialized_fill_n(out + CHAR_WIDTH, + spec_.width_ - CHAR_WIDTH, fill); + } + } else { + out = writer_.grow_buffer(CHAR_WIDTH); + } + *out = internal::CharTraits::cast(value); + } + + void visit_cstring(const char *value) { + if (spec_.type_ == 'p') + return write_pointer(value); + write(value); + } + + void visit_string(Arg::StringValue value) { + writer_.write_str(value, spec_); + } + + using ArgVisitor::visit_wstring; + + void visit_wstring(Arg::StringValue value) { + writer_.write_str(value, spec_); + } + + void visit_pointer(const void *value) { + if (spec_.type_ && spec_.type_ != 'p') + report_unknown_type(spec_.type_, "pointer"); + write_pointer(value); + } +}; + +class FormatterBase { + private: + ArgList args_; + int next_arg_index_; + + // Returns the argument with specified index. + FMT_API Arg do_get_arg(unsigned arg_index, const char *&error); + + protected: + const ArgList &args() const { return args_; } + + explicit FormatterBase(const ArgList &args) { + args_ = args; + next_arg_index_ = 0; + } + + // Returns the next argument. + Arg next_arg(const char *&error) { + if (next_arg_index_ >= 0) + return do_get_arg(internal::to_unsigned(next_arg_index_++), error); + error = "cannot switch from manual to automatic argument indexing"; + return Arg(); + } + + // Checks if manual indexing is used and returns the argument with + // specified index. + Arg get_arg(unsigned arg_index, const char *&error) { + return check_no_auto_index(error) ? do_get_arg(arg_index, error) : Arg(); + } + + bool check_no_auto_index(const char *&error) { + if (next_arg_index_ > 0) { + error = "cannot switch from automatic to manual argument indexing"; + return false; + } + next_arg_index_ = -1; + return true; + } + + template + void write(BasicWriter &w, const Char *start, const Char *end) { + if (start != end) + w << BasicStringRef(start, internal::to_unsigned(end - start)); + } +}; + +// A printf formatter. 
+template +class PrintfFormatter : private FormatterBase { + private: + void parse_flags(FormatSpec &spec, const Char *&s); + + // Returns the argument with specified index or, if arg_index is equal + // to the maximum unsigned value, the next argument. + Arg get_arg(const Char *s, + unsigned arg_index = (std::numeric_limits::max)()); + + // Parses argument index, flags and width and returns the argument index. + unsigned parse_header(const Char *&s, FormatSpec &spec); + + public: + explicit PrintfFormatter(const ArgList &args) : FormatterBase(args) {} + FMT_API void format(BasicWriter &writer, + BasicCStringRef format_str); +}; +} // namespace internal + +/** + \rst + An argument formatter based on the `curiously recurring template pattern + `_. + + To use `~fmt::BasicArgFormatter` define a subclass that implements some or + all of the visit methods with the same signatures as the methods in + `~fmt::ArgVisitor`, for example, `~fmt::ArgVisitor::visit_int()`. + Pass the subclass as the *Impl* template parameter. When a formatting + function processes an argument, it will dispatch to a visit method + specific to the argument type. For example, if the argument type is + ``double`` then the `~fmt::ArgVisitor::visit_double()` method of a subclass + will be called. If the subclass doesn't contain a method with this signature, + then a corresponding method of `~fmt::BasicArgFormatter` or its superclass + will be called. + \endrst + */ +template +class BasicArgFormatter : public internal::ArgFormatterBase { + private: + BasicFormatter &formatter_; + const Char *format_; + + public: + /** + \rst + Constructs an argument formatter object. + *formatter* is a reference to the main formatter object, *spec* contains + format specifier information for standard argument types, and *fmt* points + to the part of the format string being parsed for custom argument types. 
+ \endrst + */ + BasicArgFormatter(BasicFormatter &formatter, + FormatSpec &spec, const Char *fmt) + : internal::ArgFormatterBase(formatter.writer(), spec), + formatter_(formatter), format_(fmt) {} + + /** Formats argument of a custom (user-defined) type. */ + void visit_custom(internal::Arg::CustomValue c) { + c.format(&formatter_, c.value, &format_); + } +}; + +/** The default argument formatter. */ +template +class ArgFormatter : public BasicArgFormatter, Char> { + public: + /** Constructs an argument formatter object. */ + ArgFormatter(BasicFormatter &formatter, + FormatSpec &spec, const Char *fmt) + : BasicArgFormatter, Char>(formatter, spec, fmt) {} +}; + +/** This template formats data and writes the output to a writer. */ +template +class BasicFormatter : private internal::FormatterBase { + public: + /** The character type for the output. */ + typedef CharType Char; + + private: + BasicWriter &writer_; + internal::ArgMap map_; + + FMT_DISALLOW_COPY_AND_ASSIGN(BasicFormatter); + + using internal::FormatterBase::get_arg; + + // Checks if manual indexing is used and returns the argument with + // specified name. + internal::Arg get_arg(BasicStringRef arg_name, const char *&error); + + // Parses argument index and returns corresponding argument. + internal::Arg parse_arg_index(const Char *&s); + + // Parses argument name and returns corresponding argument. + internal::Arg parse_arg_name(const Char *&s); + + public: + /** + \rst + Constructs a ``BasicFormatter`` object. References to the arguments and + the writer are stored in the formatter object so make sure they have + appropriate lifetimes. + \endrst + */ + BasicFormatter(const ArgList &args, BasicWriter &w) + : internal::FormatterBase(args), writer_(w) {} + + /** Returns a reference to the writer associated with this formatter. */ + BasicWriter &writer() { return writer_; } + + /** Formats stored arguments and writes the output to the writer. 
*/ + void format(BasicCStringRef format_str); + + // Formats a single argument and advances format_str, a format string pointer. + const Char *format(const Char *&format_str, const internal::Arg &arg); +}; + +// Generates a comma-separated list with results of applying f to +// numbers 0..n-1. +# define FMT_GEN(n, f) FMT_GEN##n(f) +# define FMT_GEN1(f) f(0) +# define FMT_GEN2(f) FMT_GEN1(f), f(1) +# define FMT_GEN3(f) FMT_GEN2(f), f(2) +# define FMT_GEN4(f) FMT_GEN3(f), f(3) +# define FMT_GEN5(f) FMT_GEN4(f), f(4) +# define FMT_GEN6(f) FMT_GEN5(f), f(5) +# define FMT_GEN7(f) FMT_GEN6(f), f(6) +# define FMT_GEN8(f) FMT_GEN7(f), f(7) +# define FMT_GEN9(f) FMT_GEN8(f), f(8) +# define FMT_GEN10(f) FMT_GEN9(f), f(9) +# define FMT_GEN11(f) FMT_GEN10(f), f(10) +# define FMT_GEN12(f) FMT_GEN11(f), f(11) +# define FMT_GEN13(f) FMT_GEN12(f), f(12) +# define FMT_GEN14(f) FMT_GEN13(f), f(13) +# define FMT_GEN15(f) FMT_GEN14(f), f(14) + +namespace internal { +inline uint64_t make_type() { return 0; } + +template +inline uint64_t make_type(const T &arg) { + return MakeValue< BasicFormatter >::type(arg); +} + +template +struct ArgArray; + +template +struct ArgArray { + typedef Value Type[N > 0 ? N : 1]; + + template + static Value make(const T &value) { +#ifdef __clang__ + Value result = MakeValue(value); + // Workaround a bug in Apple LLVM version 4.2 (clang-425.0.28) of clang: + // https://github.com/fmtlib/fmt/issues/276 + (void)result.custom.format; + return result; +#else + return MakeValue(value); +#endif + } +}; + +template +struct ArgArray { + typedef Arg Type[N + 1]; // +1 for the list end Arg::NONE + + template + static Arg make(const T &value) { return MakeArg(value); } +}; + +#if FMT_USE_VARIADIC_TEMPLATES +template +inline uint64_t make_type(const Arg &first, const Args & ... tail) { + return make_type(first) | (make_type(tail...) 
<< 4); +} + +#else + +struct ArgType { + uint64_t type; + + ArgType() : type(0) {} + + template + ArgType(const T &arg) : type(make_type(arg)) {} +}; + +# define FMT_ARG_TYPE_DEFAULT(n) ArgType t##n = ArgType() + +inline uint64_t make_type(FMT_GEN15(FMT_ARG_TYPE_DEFAULT)) { + return t0.type | (t1.type << 4) | (t2.type << 8) | (t3.type << 12) | + (t4.type << 16) | (t5.type << 20) | (t6.type << 24) | (t7.type << 28) | + (t8.type << 32) | (t9.type << 36) | (t10.type << 40) | (t11.type << 44) | + (t12.type << 48) | (t13.type << 52) | (t14.type << 56); +} +#endif +} // namespace internal + +# define FMT_MAKE_TEMPLATE_ARG(n) typename T##n +# define FMT_MAKE_ARG_TYPE(n) T##n +# define FMT_MAKE_ARG(n) const T##n &v##n +# define FMT_ASSIGN_char(n) \ + arr[n] = fmt::internal::MakeValue< fmt::BasicFormatter >(v##n) +# define FMT_ASSIGN_wchar_t(n) \ + arr[n] = fmt::internal::MakeValue< fmt::BasicFormatter >(v##n) + +#if FMT_USE_VARIADIC_TEMPLATES +// Defines a variadic function returning void. +# define FMT_VARIADIC_VOID(func, arg_type) \ + template \ + void func(arg_type arg0, const Args & ... args) { \ + typedef fmt::internal::ArgArray ArgArray; \ + typename ArgArray::Type array{ \ + ArgArray::template make >(args)...}; \ + func(arg0, fmt::ArgList(fmt::internal::make_type(args...), array)); \ + } + +// Defines a variadic constructor. +# define FMT_VARIADIC_CTOR(ctor, func, arg0_type, arg1_type) \ + template \ + ctor(arg0_type arg0, arg1_type arg1, const Args & ... args) { \ + typedef fmt::internal::ArgArray ArgArray; \ + typename ArgArray::Type array{ \ + ArgArray::template make >(args)...}; \ + func(arg0, arg1, fmt::ArgList(fmt::internal::make_type(args...), array)); \ + } + +#else + +# define FMT_MAKE_REF(n) \ + fmt::internal::MakeValue< fmt::BasicFormatter >(v##n) +# define FMT_MAKE_REF2(n) v##n + +// Defines a wrapper for a function taking one argument of type arg_type +// and n additional arguments of arbitrary types. 
+# define FMT_WRAP1(func, arg_type, n) \ + template \ + inline void func(arg_type arg1, FMT_GEN(n, FMT_MAKE_ARG)) { \ + const fmt::internal::ArgArray::Type array = {FMT_GEN(n, FMT_MAKE_REF)}; \ + func(arg1, fmt::ArgList( \ + fmt::internal::make_type(FMT_GEN(n, FMT_MAKE_REF2)), array)); \ + } + +// Emulates a variadic function returning void on a pre-C++11 compiler. +# define FMT_VARIADIC_VOID(func, arg_type) \ + inline void func(arg_type arg) { func(arg, fmt::ArgList()); } \ + FMT_WRAP1(func, arg_type, 1) FMT_WRAP1(func, arg_type, 2) \ + FMT_WRAP1(func, arg_type, 3) FMT_WRAP1(func, arg_type, 4) \ + FMT_WRAP1(func, arg_type, 5) FMT_WRAP1(func, arg_type, 6) \ + FMT_WRAP1(func, arg_type, 7) FMT_WRAP1(func, arg_type, 8) \ + FMT_WRAP1(func, arg_type, 9) FMT_WRAP1(func, arg_type, 10) + +# define FMT_CTOR(ctor, func, arg0_type, arg1_type, n) \ + template \ + ctor(arg0_type arg0, arg1_type arg1, FMT_GEN(n, FMT_MAKE_ARG)) { \ + const fmt::internal::ArgArray::Type array = {FMT_GEN(n, FMT_MAKE_REF)}; \ + func(arg0, arg1, fmt::ArgList( \ + fmt::internal::make_type(FMT_GEN(n, FMT_MAKE_REF2)), array)); \ + } + +// Emulates a variadic constructor on a pre-C++11 compiler. +# define FMT_VARIADIC_CTOR(ctor, func, arg0_type, arg1_type) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 1) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 2) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 3) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 4) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 5) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 6) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 7) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 8) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 9) \ + FMT_CTOR(ctor, func, arg0_type, arg1_type, 10) +#endif + +// Generates a comma-separated list with results of applying f to pairs +// (argument, index). 
+#define FMT_FOR_EACH1(f, x0) f(x0, 0) +#define FMT_FOR_EACH2(f, x0, x1) \ + FMT_FOR_EACH1(f, x0), f(x1, 1) +#define FMT_FOR_EACH3(f, x0, x1, x2) \ + FMT_FOR_EACH2(f, x0 ,x1), f(x2, 2) +#define FMT_FOR_EACH4(f, x0, x1, x2, x3) \ + FMT_FOR_EACH3(f, x0, x1, x2), f(x3, 3) +#define FMT_FOR_EACH5(f, x0, x1, x2, x3, x4) \ + FMT_FOR_EACH4(f, x0, x1, x2, x3), f(x4, 4) +#define FMT_FOR_EACH6(f, x0, x1, x2, x3, x4, x5) \ + FMT_FOR_EACH5(f, x0, x1, x2, x3, x4), f(x5, 5) +#define FMT_FOR_EACH7(f, x0, x1, x2, x3, x4, x5, x6) \ + FMT_FOR_EACH6(f, x0, x1, x2, x3, x4, x5), f(x6, 6) +#define FMT_FOR_EACH8(f, x0, x1, x2, x3, x4, x5, x6, x7) \ + FMT_FOR_EACH7(f, x0, x1, x2, x3, x4, x5, x6), f(x7, 7) +#define FMT_FOR_EACH9(f, x0, x1, x2, x3, x4, x5, x6, x7, x8) \ + FMT_FOR_EACH8(f, x0, x1, x2, x3, x4, x5, x6, x7), f(x8, 8) +#define FMT_FOR_EACH10(f, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) \ + FMT_FOR_EACH9(f, x0, x1, x2, x3, x4, x5, x6, x7, x8), f(x9, 9) + +/** + An error returned by an operating system or a language runtime, + for example a file opening error. +*/ +class SystemError : public internal::RuntimeError { + private: + void init(int err_code, CStringRef format_str, ArgList args); + + protected: + int error_code_; + + typedef char Char; // For FMT_VARIADIC_CTOR. + + SystemError() {} + + public: + /** + \rst + Constructs a :class:`fmt::SystemError` object with the description + of the form + + .. parsed-literal:: + **: ** + + where ** is the formatted message and ** is + the system message corresponding to the error code. + *error_code* is a system error code as given by ``errno``. + If *error_code* is not a valid error code such as -1, the system message + may look like "Unknown error -1" and is platform-dependent. + + **Example**:: + + // This throws a SystemError with the description + // cannot open file 'madeup': No such file or directory + // or similar (system message may vary). 
+ const char *filename = "madeup"; + std::FILE *file = std::fopen(filename, "r"); + if (!file) + throw fmt::SystemError(errno, "cannot open file '{}'", filename); + \endrst + */ + SystemError(int error_code, CStringRef message) { + init(error_code, message, ArgList()); + } + FMT_VARIADIC_CTOR(SystemError, init, int, CStringRef) + + int error_code() const { return error_code_; } +}; + +/** + \rst + This template provides operations for formatting and writing data into + a character stream. The output is stored in a buffer provided by a subclass + such as :class:`fmt::BasicMemoryWriter`. + + You can use one of the following typedefs for common character types: + + +---------+----------------------+ + | Type | Definition | + +=========+======================+ + | Writer | BasicWriter | + +---------+----------------------+ + | WWriter | BasicWriter | + +---------+----------------------+ + + \endrst + */ +template +class BasicWriter { + private: + // Output buffer. + Buffer &buffer_; + + FMT_DISALLOW_COPY_AND_ASSIGN(BasicWriter); + + typedef typename internal::CharTraits::CharPtr CharPtr; + +#if FMT_SECURE_SCL + // Returns pointer value. + static Char *get(CharPtr p) { return p.base(); } +#else + static Char *get(Char *p) { return p; } +#endif + + // Fills the padding around the content and returns the pointer to the + // content area. + static CharPtr fill_padding(CharPtr buffer, + unsigned total_size, std::size_t content_size, wchar_t fill); + + // Grows the buffer by n characters and returns a pointer to the newly + // allocated area. + CharPtr grow_buffer(std::size_t n) { + std::size_t size = buffer_.size(); + buffer_.resize(size + n); + return internal::make_ptr(&buffer_[size], n); + } + + // Writes an unsigned decimal integer. 
+ template + Char *write_unsigned_decimal(UInt value, unsigned prefix_size = 0) { + unsigned num_digits = internal::count_digits(value); + Char *ptr = get(grow_buffer(prefix_size + num_digits)); + internal::format_decimal(ptr + prefix_size, value, num_digits); + return ptr; + } + + // Writes a decimal integer. + template + void write_decimal(Int value) { + typedef typename internal::IntTraits::MainType MainType; + MainType abs_value = static_cast(value); + if (internal::is_negative(value)) { + abs_value = 0 - abs_value; + *write_unsigned_decimal(abs_value, 1) = '-'; + } else { + write_unsigned_decimal(abs_value, 0); + } + } + + // Prepare a buffer for integer formatting. + CharPtr prepare_int_buffer(unsigned num_digits, + const EmptySpec &, const char *prefix, unsigned prefix_size) { + unsigned size = prefix_size + num_digits; + CharPtr p = grow_buffer(size); + std::uninitialized_copy(prefix, prefix + prefix_size, p); + return p + size - 1; + } + + template + CharPtr prepare_int_buffer(unsigned num_digits, + const Spec &spec, const char *prefix, unsigned prefix_size); + + // Formats an integer. + template + void write_int(T value, Spec spec); + + // Formats a floating-point number (double or long double). + template + void write_double(T value, const FormatSpec &spec); + + // Writes a formatted string. + template + CharPtr write_str(const StrChar *s, std::size_t size, const AlignSpec &spec); + + template + void write_str(const internal::Arg::StringValue &str, + const FormatSpec &spec); + + // This following methods are private to disallow writing wide characters + // and strings to a char stream. If you want to print a wide string as a + // pointer as std::ostream does, cast it to const void*. + // Do not implement! + void operator<<(typename internal::WCharHelper::Unsupported); + void operator<<( + typename internal::WCharHelper::Unsupported); + + // Appends floating-point length specifier to the format string. 
+ // The second argument is only used for overload resolution. + void append_float_length(Char *&format_ptr, long double) { + *format_ptr++ = 'L'; + } + + template + void append_float_length(Char *&, T) {} + + template + friend class internal::ArgFormatterBase; + + friend class internal::PrintfArgFormatter; + + protected: + /** + Constructs a ``BasicWriter`` object. + */ + explicit BasicWriter(Buffer &b) : buffer_(b) {} + + public: + /** + \rst + Destroys a ``BasicWriter`` object. + \endrst + */ + virtual ~BasicWriter() {} + + /** + Returns the total number of characters written. + */ + std::size_t size() const { return buffer_.size(); } + + /** + Returns a pointer to the output buffer content. No terminating null + character is appended. + */ + const Char *data() const FMT_NOEXCEPT { return &buffer_[0]; } + + /** + Returns a pointer to the output buffer content with terminating null + character appended. + */ + const Char *c_str() const { + std::size_t size = buffer_.size(); + buffer_.reserve(size + 1); + buffer_[size] = '\0'; + return &buffer_[0]; + } + + /** + \rst + Returns the content of the output buffer as an `std::string`. + \endrst + */ + std::basic_string str() const { + return std::basic_string(&buffer_[0], buffer_.size()); + } + + /** + \rst + Writes formatted data. + + *args* is an argument list representing arbitrary arguments. + + **Example**:: + + MemoryWriter out; + out.write("Current point:\n"); + out.write("({:+f}, {:+f})", -3.14, 3.14); + + This will write the following output to the ``out`` object: + + .. code-block:: none + + Current point: + (-3.140000, +3.140000) + + The output can be accessed using :func:`data()`, :func:`c_str` or + :func:`str` methods. + + See also :ref:`syntax`. 
+ \endrst + */ + void write(BasicCStringRef format, ArgList args) { + BasicFormatter(args, *this).format(format); + } + FMT_VARIADIC_VOID(write, BasicCStringRef) + + BasicWriter &operator<<(int value) { + write_decimal(value); + return *this; + } + BasicWriter &operator<<(unsigned value) { + return *this << IntFormatSpec(value); + } + BasicWriter &operator<<(long value) { + write_decimal(value); + return *this; + } + BasicWriter &operator<<(unsigned long value) { + return *this << IntFormatSpec(value); + } + BasicWriter &operator<<(LongLong value) { + write_decimal(value); + return *this; + } + + /** + \rst + Formats *value* and writes it to the stream. + \endrst + */ + BasicWriter &operator<<(ULongLong value) { + return *this << IntFormatSpec(value); + } + + BasicWriter &operator<<(double value) { + write_double(value, FormatSpec()); + return *this; + } + + /** + \rst + Formats *value* using the general format for floating-point numbers + (``'g'``) and writes it to the stream. + \endrst + */ + BasicWriter &operator<<(long double value) { + write_double(value, FormatSpec()); + return *this; + } + + /** + Writes a character to the stream. + */ + BasicWriter &operator<<(char value) { + buffer_.push_back(value); + return *this; + } + + BasicWriter &operator<<( + typename internal::WCharHelper::Supported value) { + buffer_.push_back(value); + return *this; + } + + /** + \rst + Writes *value* to the stream. 
+ \endrst + */ + BasicWriter &operator<<(fmt::BasicStringRef value) { + const Char *str = value.data(); + buffer_.append(str, str + value.size()); + return *this; + } + + BasicWriter &operator<<( + typename internal::WCharHelper::Supported value) { + const char *str = value.data(); + buffer_.append(str, str + value.size()); + return *this; + } + + template + BasicWriter &operator<<(IntFormatSpec spec) { + internal::CharTraits::convert(FillChar()); + write_int(spec.value(), spec); + return *this; + } + + template + BasicWriter &operator<<(const StrFormatSpec &spec) { + const StrChar *s = spec.str(); + write_str(s, std::char_traits::length(s), spec); + return *this; + } + + void clear() FMT_NOEXCEPT { buffer_.clear(); } + + Buffer &buffer() FMT_NOEXCEPT { return buffer_; } +}; + +template +template +typename BasicWriter::CharPtr BasicWriter::write_str( + const StrChar *s, std::size_t size, const AlignSpec &spec) { + CharPtr out = CharPtr(); + if (spec.width() > size) { + out = grow_buffer(spec.width()); + Char fill = internal::CharTraits::cast(spec.fill()); + if (spec.align() == ALIGN_RIGHT) { + std::uninitialized_fill_n(out, spec.width() - size, fill); + out += spec.width() - size; + } else if (spec.align() == ALIGN_CENTER) { + out = fill_padding(out, spec.width(), size, fill); + } else { + std::uninitialized_fill_n(out + size, spec.width() - size, fill); + } + } else { + out = grow_buffer(size); + } + std::uninitialized_copy(s, s + size, out); + return out; +} + +template +template +void BasicWriter::write_str( + const internal::Arg::StringValue &s, const FormatSpec &spec) { + // Check if StrChar is convertible to Char. 
+ internal::CharTraits::convert(StrChar()); + if (spec.type_ && spec.type_ != 's') + internal::report_unknown_type(spec.type_, "string"); + const StrChar *str_value = s.value; + std::size_t str_size = s.size; + if (str_size == 0) { + if (!str_value) { + FMT_THROW(FormatError("string pointer is null")); + return; + } + } + std::size_t precision = static_cast(spec.precision_); + if (spec.precision_ >= 0 && precision < str_size) + str_size = precision; + write_str(str_value, str_size, spec); +} + +template +typename BasicWriter::CharPtr + BasicWriter::fill_padding( + CharPtr buffer, unsigned total_size, + std::size_t content_size, wchar_t fill) { + std::size_t padding = total_size - content_size; + std::size_t left_padding = padding / 2; + Char fill_char = internal::CharTraits::cast(fill); + std::uninitialized_fill_n(buffer, left_padding, fill_char); + buffer += left_padding; + CharPtr content = buffer; + std::uninitialized_fill_n(buffer + content_size, + padding - left_padding, fill_char); + return content; +} + +template +template +typename BasicWriter::CharPtr + BasicWriter::prepare_int_buffer( + unsigned num_digits, const Spec &spec, + const char *prefix, unsigned prefix_size) { + unsigned width = spec.width(); + Alignment align = spec.align(); + Char fill = internal::CharTraits::cast(spec.fill()); + if (spec.precision() > static_cast(num_digits)) { + // Octal prefix '0' is counted as a digit, so ignore it if precision + // is specified. 
+ if (prefix_size > 0 && prefix[prefix_size - 1] == '0') + --prefix_size; + unsigned number_size = + prefix_size + internal::to_unsigned(spec.precision()); + AlignSpec subspec(number_size, '0', ALIGN_NUMERIC); + if (number_size >= width) + return prepare_int_buffer(num_digits, subspec, prefix, prefix_size); + buffer_.reserve(width); + unsigned fill_size = width - number_size; + if (align != ALIGN_LEFT) { + CharPtr p = grow_buffer(fill_size); + std::uninitialized_fill(p, p + fill_size, fill); + } + CharPtr result = prepare_int_buffer( + num_digits, subspec, prefix, prefix_size); + if (align == ALIGN_LEFT) { + CharPtr p = grow_buffer(fill_size); + std::uninitialized_fill(p, p + fill_size, fill); + } + return result; + } + unsigned size = prefix_size + num_digits; + if (width <= size) { + CharPtr p = grow_buffer(size); + std::uninitialized_copy(prefix, prefix + prefix_size, p); + return p + size - 1; + } + CharPtr p = grow_buffer(width); + CharPtr end = p + width; + if (align == ALIGN_LEFT) { + std::uninitialized_copy(prefix, prefix + prefix_size, p); + p += size; + std::uninitialized_fill(p, end, fill); + } else if (align == ALIGN_CENTER) { + p = fill_padding(p, width, size, fill); + std::uninitialized_copy(prefix, prefix + prefix_size, p); + p += size; + } else { + if (align == ALIGN_NUMERIC) { + if (prefix_size != 0) { + p = std::uninitialized_copy(prefix, prefix + prefix_size, p); + size -= prefix_size; + } + } else { + std::uninitialized_copy(prefix, prefix + prefix_size, end - size); + } + std::uninitialized_fill(p, end - size, fill); + p = end; + } + return p - 1; +} + +template +template +void BasicWriter::write_int(T value, Spec spec) { + unsigned prefix_size = 0; + typedef typename internal::IntTraits::MainType UnsignedType; + UnsignedType abs_value = static_cast(value); + char prefix[4] = ""; + if (internal::is_negative(value)) { + prefix[0] = '-'; + ++prefix_size; + abs_value = 0 - abs_value; + } else if (spec.flag(SIGN_FLAG)) { + prefix[0] = 
spec.flag(PLUS_FLAG) ? '+' : ' '; + ++prefix_size; + } + switch (spec.type()) { + case 0: case 'd': { + unsigned num_digits = internal::count_digits(abs_value); + CharPtr p = prepare_int_buffer(num_digits, spec, prefix, prefix_size) + 1; + internal::format_decimal(get(p), abs_value, 0); + break; + } + case 'x': case 'X': { + UnsignedType n = abs_value; + if (spec.flag(HASH_FLAG)) { + prefix[prefix_size++] = '0'; + prefix[prefix_size++] = spec.type(); + } + unsigned num_digits = 0; + do { + ++num_digits; + } while ((n >>= 4) != 0); + Char *p = get(prepare_int_buffer( + num_digits, spec, prefix, prefix_size)); + n = abs_value; + const char *digits = spec.type() == 'x' ? + "0123456789abcdef" : "0123456789ABCDEF"; + do { + *p-- = digits[n & 0xf]; + } while ((n >>= 4) != 0); + break; + } + case 'b': case 'B': { + UnsignedType n = abs_value; + if (spec.flag(HASH_FLAG)) { + prefix[prefix_size++] = '0'; + prefix[prefix_size++] = spec.type(); + } + unsigned num_digits = 0; + do { + ++num_digits; + } while ((n >>= 1) != 0); + Char *p = get(prepare_int_buffer(num_digits, spec, prefix, prefix_size)); + n = abs_value; + do { + *p-- = static_cast('0' + (n & 1)); + } while ((n >>= 1) != 0); + break; + } + case 'o': { + UnsignedType n = abs_value; + if (spec.flag(HASH_FLAG)) + prefix[prefix_size++] = '0'; + unsigned num_digits = 0; + do { + ++num_digits; + } while ((n >>= 3) != 0); + Char *p = get(prepare_int_buffer(num_digits, spec, prefix, prefix_size)); + n = abs_value; + do { + *p-- = static_cast('0' + (n & 7)); + } while ((n >>= 3) != 0); + break; + } + case 'n': { + unsigned num_digits = internal::count_digits(abs_value); + fmt::StringRef sep = std::localeconv()->thousands_sep; + unsigned size = static_cast( + num_digits + sep.size() * (num_digits - 1) / 3); + CharPtr p = prepare_int_buffer(size, spec, prefix, prefix_size) + 1; + internal::format_decimal(get(p), abs_value, 0, internal::ThousandsSep(sep)); + break; + } + default: + internal::report_unknown_type( + 
spec.type(), spec.flag(CHAR_FLAG) ? "char" : "integer"); + break; + } +} + +template +template +void BasicWriter::write_double(T value, const FormatSpec &spec) { + // Check type. + char type = spec.type(); + bool upper = false; + switch (type) { + case 0: + type = 'g'; + break; + case 'e': case 'f': case 'g': case 'a': + break; + case 'F': +#ifdef _MSC_VER + // MSVC's printf doesn't support 'F'. + type = 'f'; +#endif + // Fall through. + case 'E': case 'G': case 'A': + upper = true; + break; + default: + internal::report_unknown_type(type, "double"); + break; + } + + char sign = 0; + // Use isnegative instead of value < 0 because the latter is always + // false for NaN. + if (internal::FPUtil::isnegative(static_cast(value))) { + sign = '-'; + value = -value; + } else if (spec.flag(SIGN_FLAG)) { + sign = spec.flag(PLUS_FLAG) ? '+' : ' '; + } + + if (internal::FPUtil::isnotanumber(value)) { + // Format NaN ourselves because sprintf's output is not consistent + // across platforms. + std::size_t nan_size = 4; + const char *nan = upper ? " NAN" : " nan"; + if (!sign) { + --nan_size; + ++nan; + } + CharPtr out = write_str(nan, nan_size, spec); + if (sign) + *out = sign; + return; + } + + if (internal::FPUtil::isinfinity(value)) { + // Format infinity ourselves because sprintf's output is not consistent + // across platforms. + std::size_t inf_size = 4; + const char *inf = upper ? " INF" : " inf"; + if (!sign) { + --inf_size; + ++inf; + } + CharPtr out = write_str(inf, inf_size, spec); + if (sign) + *out = sign; + return; + } + + std::size_t offset = buffer_.size(); + unsigned width = spec.width(); + if (sign) { + buffer_.reserve(buffer_.size() + (width > 1u ? width : 1u)); + if (width > 0) + --width; + ++offset; + } + + // Build format string. 
+ enum { MAX_FORMAT_SIZE = 10}; // longest format: %#-*.*Lg + Char format[MAX_FORMAT_SIZE]; + Char *format_ptr = format; + *format_ptr++ = '%'; + unsigned width_for_sprintf = width; + if (spec.flag(HASH_FLAG)) + *format_ptr++ = '#'; + if (spec.align() == ALIGN_CENTER) { + width_for_sprintf = 0; + } else { + if (spec.align() == ALIGN_LEFT) + *format_ptr++ = '-'; + if (width != 0) + *format_ptr++ = '*'; + } + if (spec.precision() >= 0) { + *format_ptr++ = '.'; + *format_ptr++ = '*'; + } + + append_float_length(format_ptr, value); + *format_ptr++ = type; + *format_ptr = '\0'; + + // Format using snprintf. + Char fill = internal::CharTraits::cast(spec.fill()); + unsigned n = 0; + Char *start = 0; + for (;;) { + std::size_t buffer_size = buffer_.capacity() - offset; +#ifdef _MSC_VER + // MSVC's vsnprintf_s doesn't work with zero size, so reserve + // space for at least one extra character to make the size non-zero. + // Note that the buffer's capacity will increase by more than 1. + if (buffer_size == 0) { + buffer_.reserve(offset + 1); + buffer_size = buffer_.capacity() - offset; + } +#endif + start = &buffer_[offset]; + int result = internal::CharTraits::format_float( + start, buffer_size, format, width_for_sprintf, spec.precision(), value); + if (result >= 0) { + n = internal::to_unsigned(result); + if (offset + n < buffer_.capacity()) + break; // The buffer is large enough - continue with formatting. + buffer_.reserve(offset + n + 1); + } else { + // If result is negative we ask to increase the capacity by at least 1, + // but as std::vector, the buffer grows exponentially. 
+ buffer_.reserve(buffer_.capacity() + 1); + } + } + if (sign) { + if ((spec.align() != ALIGN_RIGHT && spec.align() != ALIGN_DEFAULT) || + *start != ' ') { + *(start - 1) = sign; + sign = 0; + } else { + *(start - 1) = fill; + } + ++n; + } + if (spec.align() == ALIGN_CENTER && spec.width() > n) { + width = spec.width(); + CharPtr p = grow_buffer(width); + std::memmove(get(p) + (width - n) / 2, get(p), n * sizeof(Char)); + fill_padding(p, spec.width(), n, fill); + return; + } + if (spec.fill() != ' ' || sign) { + while (*start == ' ') + *start++ = fill; + if (sign) + *(start - 1) = sign; + } + grow_buffer(n); +} + +/** + \rst + This class template provides operations for formatting and writing data + into a character stream. The output is stored in a memory buffer that grows + dynamically. + + You can use one of the following typedefs for common character types + and the standard allocator: + + +---------------+-----------------------------------------------------+ + | Type | Definition | + +===============+=====================================================+ + | MemoryWriter | BasicMemoryWriter> | + +---------------+-----------------------------------------------------+ + | WMemoryWriter | BasicMemoryWriter> | + +---------------+-----------------------------------------------------+ + + **Example**:: + + MemoryWriter out; + out << "The answer is " << 42 << "\n"; + out.write("({:+f}, {:+f})", -3.14, 3.14); + + This will write the following output to the ``out`` object: + + .. code-block:: none + + The answer is 42 + (-3.140000, +3.140000) + + The output can be converted to an ``std::string`` with ``out.str()`` or + accessed as a C string with ``out.c_str()``. 
+ \endrst + */ +template > +class BasicMemoryWriter : public BasicWriter { + private: + internal::MemoryBuffer buffer_; + + public: + explicit BasicMemoryWriter(const Allocator& alloc = Allocator()) + : BasicWriter(buffer_), buffer_(alloc) {} + +#if FMT_USE_RVALUE_REFERENCES + /** + \rst + Constructs a :class:`fmt::BasicMemoryWriter` object moving the content + of the other object to it. + \endrst + */ + BasicMemoryWriter(BasicMemoryWriter &&other) + : BasicWriter(buffer_), buffer_(std::move(other.buffer_)) { + } + + /** + \rst + Moves the content of the other ``BasicMemoryWriter`` object to this one. + \endrst + */ + BasicMemoryWriter &operator=(BasicMemoryWriter &&other) { + buffer_ = std::move(other.buffer_); + return *this; + } +#endif +}; + +typedef BasicMemoryWriter MemoryWriter; +typedef BasicMemoryWriter WMemoryWriter; + +/** + \rst + This class template provides operations for formatting and writing data + into a fixed-size array. For writing into a dynamically growing buffer + use :class:`fmt::BasicMemoryWriter`. + + Any write method will throw ``std::runtime_error`` if the output doesn't fit + into the array. + + You can use one of the following typedefs for common character types: + + +--------------+---------------------------+ + | Type | Definition | + +==============+===========================+ + | ArrayWriter | BasicArrayWriter | + +--------------+---------------------------+ + | WArrayWriter | BasicArrayWriter | + +--------------+---------------------------+ + \endrst + */ +template +class BasicArrayWriter : public BasicWriter { + private: + internal::FixedBuffer buffer_; + + public: + /** + \rst + Constructs a :class:`fmt::BasicArrayWriter` object for *array* of the + given size. + \endrst + */ + BasicArrayWriter(Char *array, std::size_t size) + : BasicWriter(buffer_), buffer_(array, size) {} + + /** + \rst + Constructs a :class:`fmt::BasicArrayWriter` object for *array* of the + size known at compile time. 
+ \endrst + */ + template + explicit BasicArrayWriter(Char (&array)[SIZE]) + : BasicWriter(buffer_), buffer_(array, SIZE) {} +}; + +typedef BasicArrayWriter ArrayWriter; +typedef BasicArrayWriter WArrayWriter; + +// Reports a system error without throwing an exception. +// Can be used to report errors from destructors. +FMT_API void report_system_error(int error_code, + StringRef message) FMT_NOEXCEPT; + +#if FMT_USE_WINDOWS_H + +/** A Windows error. */ +class WindowsError : public SystemError { + private: + FMT_API void init(int error_code, CStringRef format_str, ArgList args); + + public: + /** + \rst + Constructs a :class:`fmt::WindowsError` object with the description + of the form + + .. parsed-literal:: + **: ** + + where ** is the formatted message and ** is the + system message corresponding to the error code. + *error_code* is a Windows error code as given by ``GetLastError``. + If *error_code* is not a valid error code such as -1, the system message + will look like "error -1". + + **Example**:: + + // This throws a WindowsError with the description + // cannot open file 'madeup': The system cannot find the file specified. + // or similar (system message may vary). + const char *filename = "madeup"; + LPOFSTRUCT of = LPOFSTRUCT(); + HFILE file = OpenFile(filename, &of, OF_READ); + if (file == HFILE_ERROR) { + throw fmt::WindowsError(GetLastError(), + "cannot open file '{}'", filename); + } + \endrst + */ + WindowsError(int error_code, CStringRef message) { + init(error_code, message, ArgList()); + } + FMT_VARIADIC_CTOR(WindowsError, init, int, CStringRef) +}; + +// Reports a Windows error without throwing an exception. +// Can be used to report errors from destructors. +FMT_API void report_windows_error(int error_code, + StringRef message) FMT_NOEXCEPT; + +#endif + +enum Color { BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE }; + +/** + Formats a string and prints it to stdout using ANSI escape sequences + to specify color (experimental). 
+ Example: + print_colored(fmt::RED, "Elapsed time: {0:.2f} seconds", 1.23); + */ +FMT_API void print_colored(Color c, CStringRef format, ArgList args); + +/** + \rst + Formats arguments and returns the result as a string. + + **Example**:: + + std::string message = format("The answer is {}", 42); + \endrst +*/ +inline std::string format(CStringRef format_str, ArgList args) { + MemoryWriter w; + w.write(format_str, args); + return w.str(); +} + +inline std::wstring format(WCStringRef format_str, ArgList args) { + WMemoryWriter w; + w.write(format_str, args); + return w.str(); +} + +/** + \rst + Prints formatted data to the file *f*. + + **Example**:: + + print(stderr, "Don't {}!", "panic"); + \endrst + */ +FMT_API void print(std::FILE *f, CStringRef format_str, ArgList args); + +/** + \rst + Prints formatted data to ``stdout``. + + **Example**:: + + print("Elapsed time: {0:.2f} seconds", 1.23); + \endrst + */ +FMT_API void print(CStringRef format_str, ArgList args); + +template +void printf(BasicWriter &w, BasicCStringRef format, ArgList args) { + internal::PrintfFormatter(args).format(w, format); +} + +/** + \rst + Formats arguments and returns the result as a string. + + **Example**:: + + std::string message = fmt::sprintf("The answer is %d", 42); + \endrst +*/ +inline std::string sprintf(CStringRef format, ArgList args) { + MemoryWriter w; + printf(w, format, args); + return w.str(); +} + +inline std::wstring sprintf(WCStringRef format, ArgList args) { + WMemoryWriter w; + printf(w, format, args); + return w.str(); +} + +/** + \rst + Prints formatted data to the file *f*. + + **Example**:: + + fmt::fprintf(stderr, "Don't %s!", "panic"); + \endrst + */ +FMT_API int fprintf(std::FILE *f, CStringRef format, ArgList args); + +/** + \rst + Prints formatted data to ``stdout``. 
+ + **Example**:: + + fmt::printf("Elapsed time: %.2f seconds", 1.23); + \endrst + */ +inline int printf(CStringRef format, ArgList args) { + return fprintf(stdout, format, args); +} + +/** + Fast integer formatter. + */ +class FormatInt { + private: + // Buffer should be large enough to hold all digits (digits10 + 1), + // a sign and a null character. + enum {BUFFER_SIZE = std::numeric_limits::digits10 + 3}; + mutable char buffer_[BUFFER_SIZE]; + char *str_; + + // Formats value in reverse and returns the number of digits. + char *format_decimal(ULongLong value) { + char *buffer_end = buffer_ + BUFFER_SIZE - 1; + while (value >= 100) { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + unsigned index = static_cast((value % 100) * 2); + value /= 100; + *--buffer_end = internal::Data::DIGITS[index + 1]; + *--buffer_end = internal::Data::DIGITS[index]; + } + if (value < 10) { + *--buffer_end = static_cast('0' + value); + return buffer_end; + } + unsigned index = static_cast(value * 2); + *--buffer_end = internal::Data::DIGITS[index + 1]; + *--buffer_end = internal::Data::DIGITS[index]; + return buffer_end; + } + + void FormatSigned(LongLong value) { + ULongLong abs_value = static_cast(value); + bool negative = value < 0; + if (negative) + abs_value = 0 - abs_value; + str_ = format_decimal(abs_value); + if (negative) + *--str_ = '-'; + } + + public: + explicit FormatInt(int value) { FormatSigned(value); } + explicit FormatInt(long value) { FormatSigned(value); } + explicit FormatInt(LongLong value) { FormatSigned(value); } + explicit FormatInt(unsigned value) : str_(format_decimal(value)) {} + explicit FormatInt(unsigned long value) : str_(format_decimal(value)) {} + explicit FormatInt(ULongLong value) : str_(format_decimal(value)) {} + + /** Returns the number of characters written to the output buffer. 
*/ + std::size_t size() const { + return internal::to_unsigned(buffer_ - str_ + BUFFER_SIZE - 1); + } + + /** + Returns a pointer to the output buffer content. No terminating null + character is appended. + */ + const char *data() const { return str_; } + + /** + Returns a pointer to the output buffer content with terminating null + character appended. + */ + const char *c_str() const { + buffer_[BUFFER_SIZE - 1] = '\0'; + return str_; + } + + /** + \rst + Returns the content of the output buffer as an ``std::string``. + \endrst + */ + std::string str() const { return std::string(str_, size()); } +}; + +// Formats a decimal integer value writing into buffer and returns +// a pointer to the end of the formatted string. This function doesn't +// write a terminating null character. +template +inline void format_decimal(char *&buffer, T value) { + typedef typename internal::IntTraits::MainType MainType; + MainType abs_value = static_cast(value); + if (internal::is_negative(value)) { + *buffer++ = '-'; + abs_value = 0 - abs_value; + } + if (abs_value < 100) { + if (abs_value < 10) { + *buffer++ = static_cast('0' + abs_value); + return; + } + unsigned index = static_cast(abs_value * 2); + *buffer++ = internal::Data::DIGITS[index]; + *buffer++ = internal::Data::DIGITS[index + 1]; + return; + } + unsigned num_digits = internal::count_digits(abs_value); + internal::format_decimal(buffer, abs_value, num_digits); + buffer += num_digits; +} + +/** + \rst + Returns a named argument for formatting functions. + + **Example**:: + + print("Elapsed time: {s:.2f} seconds", arg("s", 1.23)); + + \endrst + */ +template +inline internal::NamedArg arg(StringRef name, const T &arg) { + return internal::NamedArg(name, arg); +} + +template +inline internal::NamedArg arg(WStringRef name, const T &arg) { + return internal::NamedArg(name, arg); +} + +// The following two functions are deleted intentionally to disable +// nested named arguments as in ``format("{}", arg("a", arg("b", 42)))``. 
+template +void arg(StringRef, const internal::NamedArg&) FMT_DELETED_OR_UNDEFINED; +template +void arg(WStringRef, const internal::NamedArg&) FMT_DELETED_OR_UNDEFINED; +} + +#if FMT_GCC_VERSION +// Use the system_header pragma to suppress warnings about variadic macros +// because suppressing -Wvariadic-macros with the diagnostic pragma doesn't +// work. It is used at the end because we want to suppress as little warnings +// as possible. +# pragma GCC system_header +#endif + +// This is used to work around VC++ bugs in handling variadic macros. +#define FMT_EXPAND(args) args + +// Returns the number of arguments. +// Based on https://groups.google.com/forum/#!topic/comp.std.c/d-6Mj5Lko_s. +#define FMT_NARG(...) FMT_NARG_(__VA_ARGS__, FMT_RSEQ_N()) +#define FMT_NARG_(...) FMT_EXPAND(FMT_ARG_N(__VA_ARGS__)) +#define FMT_ARG_N(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define FMT_RSEQ_N() 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 + +#define FMT_CONCAT(a, b) a##b +#define FMT_FOR_EACH_(N, f, ...) \ + FMT_EXPAND(FMT_CONCAT(FMT_FOR_EACH, N)(f, __VA_ARGS__)) +#define FMT_FOR_EACH(f, ...) \ + FMT_EXPAND(FMT_FOR_EACH_(FMT_NARG(__VA_ARGS__), f, __VA_ARGS__)) + +#define FMT_ADD_ARG_NAME(type, index) type arg##index +#define FMT_GET_ARG_NAME(type, index) arg##index + +#if FMT_USE_VARIADIC_TEMPLATES +# define FMT_VARIADIC_(Char, ReturnType, func, call, ...) \ + template \ + ReturnType func(FMT_FOR_EACH(FMT_ADD_ARG_NAME, __VA_ARGS__), \ + const Args & ... args) { \ + typedef fmt::internal::ArgArray ArgArray; \ + typename ArgArray::Type array{ \ + ArgArray::template make >(args)...}; \ + call(FMT_FOR_EACH(FMT_GET_ARG_NAME, __VA_ARGS__), \ + fmt::ArgList(fmt::internal::make_type(args...), array)); \ + } +#else +// Defines a wrapper for a function taking __VA_ARGS__ arguments +// and n additional arguments of arbitrary types. +# define FMT_WRAP(Char, ReturnType, func, call, n, ...) 
\ + template \ + inline ReturnType func(FMT_FOR_EACH(FMT_ADD_ARG_NAME, __VA_ARGS__), \ + FMT_GEN(n, FMT_MAKE_ARG)) { \ + fmt::internal::ArgArray::Type arr; \ + FMT_GEN(n, FMT_ASSIGN_##Char); \ + call(FMT_FOR_EACH(FMT_GET_ARG_NAME, __VA_ARGS__), fmt::ArgList( \ + fmt::internal::make_type(FMT_GEN(n, FMT_MAKE_REF2)), arr)); \ + } + +# define FMT_VARIADIC_(Char, ReturnType, func, call, ...) \ + inline ReturnType func(FMT_FOR_EACH(FMT_ADD_ARG_NAME, __VA_ARGS__)) { \ + call(FMT_FOR_EACH(FMT_GET_ARG_NAME, __VA_ARGS__), fmt::ArgList()); \ + } \ + FMT_WRAP(Char, ReturnType, func, call, 1, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 2, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 3, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 4, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 5, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 6, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 7, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 8, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 9, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 10, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 11, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 12, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 13, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 14, __VA_ARGS__) \ + FMT_WRAP(Char, ReturnType, func, call, 15, __VA_ARGS__) +#endif // FMT_USE_VARIADIC_TEMPLATES + +/** + \rst + Defines a variadic function with the specified return type, function name + and argument types passed as variable arguments to this macro. + + **Example**:: + + void print_error(const char *file, int line, const char *format, + fmt::ArgList args) { + fmt::print("{}: {}: ", file, line); + fmt::print(format, args); + } + FMT_VARIADIC(void, print_error, const char *, int, const char *) + + ``FMT_VARIADIC`` is used for compatibility with legacy C++ compilers that + don't implement variadic templates. 
You don't have to use this macro if + you don't need legacy compiler support and can use variadic templates + directly:: + + template + void print_error(const char *file, int line, const char *format, + const Args & ... args) { + fmt::print("{}: {}: ", file, line); + fmt::print(format, args...); + } + \endrst + */ +#define FMT_VARIADIC(ReturnType, func, ...) \ + FMT_VARIADIC_(char, ReturnType, func, return func, __VA_ARGS__) + +#define FMT_VARIADIC_W(ReturnType, func, ...) \ + FMT_VARIADIC_(wchar_t, ReturnType, func, return func, __VA_ARGS__) + +#define FMT_CAPTURE_ARG_(id, index) ::fmt::arg(#id, id) + +#define FMT_CAPTURE_ARG_W_(id, index) ::fmt::arg(L###id, id) + +/** + \rst + Convenient macro to capture the arguments' names and values into several + ``fmt::arg(name, value)``. + + **Example**:: + + int x = 1, y = 2; + print("point: ({x}, {y})", FMT_CAPTURE(x, y)); + // same as: + // print("point: ({x}, {y})", arg("x", x), arg("y", y)); + + \endrst + */ +#define FMT_CAPTURE(...) FMT_FOR_EACH(FMT_CAPTURE_ARG_, __VA_ARGS__) + +#define FMT_CAPTURE_W(...) FMT_FOR_EACH(FMT_CAPTURE_ARG_W_, __VA_ARGS__) + +namespace fmt { +FMT_VARIADIC(std::string, format, CStringRef) +FMT_VARIADIC_W(std::wstring, format, WCStringRef) +FMT_VARIADIC(void, print, CStringRef) +FMT_VARIADIC(void, print, std::FILE *, CStringRef) + +FMT_VARIADIC(void, print_colored, Color, CStringRef) +FMT_VARIADIC(std::string, sprintf, CStringRef) +FMT_VARIADIC_W(std::wstring, sprintf, WCStringRef) +FMT_VARIADIC(int, printf, CStringRef) +FMT_VARIADIC(int, fprintf, std::FILE *, CStringRef) + +namespace internal { +template +inline bool is_name_start(Char c) { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || '_' == c; +} + +// Parses an unsigned integer advancing s to the end of the parsed input. +// This function assumes that the first character of s is a digit. 
+template +unsigned parse_nonnegative_int(const Char *&s) { + assert('0' <= *s && *s <= '9'); + unsigned value = 0; + do { + unsigned new_value = value * 10 + (*s++ - '0'); + // Check if value wrapped around. + if (new_value < value) { + value = (std::numeric_limits::max)(); + break; + } + value = new_value; + } while ('0' <= *s && *s <= '9'); + // Convert to unsigned to prevent a warning. + unsigned max_int = (std::numeric_limits::max)(); + if (value > max_int) + FMT_THROW(FormatError("number is too big")); + return value; +} + +inline void require_numeric_argument(const Arg &arg, char spec) { + if (arg.type > Arg::LAST_NUMERIC_TYPE) { + std::string message = + fmt::format("format specifier '{}' requires numeric argument", spec); + FMT_THROW(fmt::FormatError(message)); + } +} + +template +void check_sign(const Char *&s, const Arg &arg) { + char sign = static_cast(*s); + require_numeric_argument(arg, sign); + if (arg.type == Arg::UINT || arg.type == Arg::ULONG_LONG) { + FMT_THROW(FormatError(fmt::format( + "format specifier '{}' requires signed argument", sign))); + } + ++s; +} +} // namespace internal + +template +inline internal::Arg BasicFormatter::get_arg( + BasicStringRef arg_name, const char *&error) { + if (check_no_auto_index(error)) { + map_.init(args()); + const internal::Arg *arg = map_.find(arg_name); + if (arg) + return *arg; + error = "argument not found"; + } + return internal::Arg(); +} + +template +inline internal::Arg BasicFormatter::parse_arg_index(const Char *&s) { + const char *error = 0; + internal::Arg arg = *s < '0' || *s > '9' ? + next_arg(error) : get_arg(internal::parse_nonnegative_int(s), error); + if (error) { + FMT_THROW(FormatError( + *s != '}' && *s != ':' ? 
"invalid format string" : error)); + } + return arg; +} + +template +inline internal::Arg BasicFormatter::parse_arg_name(const Char *&s) { + assert(internal::is_name_start(*s)); + const Char *start = s; + Char c; + do { + c = *++s; + } while (internal::is_name_start(c) || ('0' <= c && c <= '9')); + const char *error = 0; + internal::Arg arg = get_arg(BasicStringRef(start, s - start), error); + if (error) + FMT_THROW(FormatError(error)); + return arg; +} + +template +const Char *BasicFormatter::format( + const Char *&format_str, const internal::Arg &arg) { + using internal::Arg; + const Char *s = format_str; + FormatSpec spec; + if (*s == ':') { + if (arg.type == Arg::CUSTOM) { + arg.custom.format(this, arg.custom.value, &s); + return s; + } + ++s; + // Parse fill and alignment. + if (Char c = *s) { + const Char *p = s + 1; + spec.align_ = ALIGN_DEFAULT; + do { + switch (*p) { + case '<': + spec.align_ = ALIGN_LEFT; + break; + case '>': + spec.align_ = ALIGN_RIGHT; + break; + case '=': + spec.align_ = ALIGN_NUMERIC; + break; + case '^': + spec.align_ = ALIGN_CENTER; + break; + } + if (spec.align_ != ALIGN_DEFAULT) { + if (p != s) { + if (c == '}') break; + if (c == '{') + FMT_THROW(FormatError("invalid fill character '{'")); + s += 2; + spec.fill_ = c; + } else ++s; + if (spec.align_ == ALIGN_NUMERIC) + require_numeric_argument(arg, '='); + break; + } + } while (--p >= s); + } + + // Parse sign. + switch (*s) { + case '+': + check_sign(s, arg); + spec.flags_ |= SIGN_FLAG | PLUS_FLAG; + break; + case '-': + check_sign(s, arg); + spec.flags_ |= MINUS_FLAG; + break; + case ' ': + check_sign(s, arg); + spec.flags_ |= SIGN_FLAG; + break; + } + + if (*s == '#') { + require_numeric_argument(arg, '#'); + spec.flags_ |= HASH_FLAG; + ++s; + } + + // Parse zero flag. + if (*s == '0') { + require_numeric_argument(arg, '0'); + spec.align_ = ALIGN_NUMERIC; + spec.fill_ = '0'; + ++s; + } + + // Parse width. 
+ if ('0' <= *s && *s <= '9') { + spec.width_ = internal::parse_nonnegative_int(s); + } else if (*s == '{') { + ++s; + Arg width_arg = internal::is_name_start(*s) ? + parse_arg_name(s) : parse_arg_index(s); + if (*s++ != '}') + FMT_THROW(FormatError("invalid format string")); + ULongLong value = 0; + switch (width_arg.type) { + case Arg::INT: + if (width_arg.int_value < 0) + FMT_THROW(FormatError("negative width")); + value = width_arg.int_value; + break; + case Arg::UINT: + value = width_arg.uint_value; + break; + case Arg::LONG_LONG: + if (width_arg.long_long_value < 0) + FMT_THROW(FormatError("negative width")); + value = width_arg.long_long_value; + break; + case Arg::ULONG_LONG: + value = width_arg.ulong_long_value; + break; + default: + FMT_THROW(FormatError("width is not integer")); + } + if (value > (std::numeric_limits::max)()) + FMT_THROW(FormatError("number is too big")); + spec.width_ = static_cast(value); + } + + // Parse precision. + if (*s == '.') { + ++s; + spec.precision_ = 0; + if ('0' <= *s && *s <= '9') { + spec.precision_ = internal::parse_nonnegative_int(s); + } else if (*s == '{') { + ++s; + Arg precision_arg = internal::is_name_start(*s) ? 
+ parse_arg_name(s) : parse_arg_index(s); + if (*s++ != '}') + FMT_THROW(FormatError("invalid format string")); + ULongLong value = 0; + switch (precision_arg.type) { + case Arg::INT: + if (precision_arg.int_value < 0) + FMT_THROW(FormatError("negative precision")); + value = precision_arg.int_value; + break; + case Arg::UINT: + value = precision_arg.uint_value; + break; + case Arg::LONG_LONG: + if (precision_arg.long_long_value < 0) + FMT_THROW(FormatError("negative precision")); + value = precision_arg.long_long_value; + break; + case Arg::ULONG_LONG: + value = precision_arg.ulong_long_value; + break; + default: + FMT_THROW(FormatError("precision is not integer")); + } + if (value > (std::numeric_limits::max)()) + FMT_THROW(FormatError("number is too big")); + spec.precision_ = static_cast(value); + } else { + FMT_THROW(FormatError("missing precision specifier")); + } + if (arg.type <= Arg::LAST_INTEGER_TYPE || arg.type == Arg::POINTER) { + FMT_THROW(FormatError( + fmt::format("precision not allowed in {} format specifier", + arg.type == Arg::POINTER ? "pointer" : "integer"))); + } + } + + // Parse type. + if (*s != '}' && *s) + spec.type_ = static_cast(*s++); + } + + if (*s++ != '}') + FMT_THROW(FormatError("missing '}' in format string")); + + // Format argument. + ArgFormatter(*this, spec, s - 1).visit(arg); + return s; +} + +template +void BasicFormatter::format(BasicCStringRef format_str) { + const Char *s = format_str.c_str(); + const Char *start = s; + while (*s) { + Char c = *s++; + if (c != '{' && c != '}') continue; + if (*s == c) { + write(writer_, start, s); + start = ++s; + continue; + } + if (c == '}') + FMT_THROW(FormatError("unmatched '}' in format string")); + write(writer_, start, s - 1); + internal::Arg arg = internal::is_name_start(*s) ? 
+ parse_arg_name(s) : parse_arg_index(s); + start = s = format(s, arg); + } + write(writer_, start, s); +} +} // namespace fmt + +#if FMT_USE_USER_DEFINED_LITERALS +namespace fmt { +namespace internal { + +template +struct UdlFormat { + const Char *str; + + template + auto operator()(Args && ... args) const + -> decltype(format(str, std::forward(args)...)) { + return format(str, std::forward(args)...); + } +}; + +template +struct UdlArg { + const Char *str; + + template + NamedArg operator=(T &&value) const { + return {str, std::forward(value)}; + } +}; + +} // namespace internal + +inline namespace literals { + +/** + \rst + C++11 literal equivalent of :func:`fmt::format`. + + **Example**:: + + using namespace fmt::literals; + std::string message = "The answer is {}"_format(42); + \endrst + */ +inline internal::UdlFormat +operator"" _format(const char *s, std::size_t) { return {s}; } +inline internal::UdlFormat +operator"" _format(const wchar_t *s, std::size_t) { return {s}; } + +/** + \rst + C++11 literal equivalent of :func:`fmt::arg`. + + **Example**:: + + using namespace fmt::literals; + print("Elapsed time: {s:.2f} seconds", "s"_a=1.23); + \endrst + */ +inline internal::UdlArg +operator"" _a(const char *s, std::size_t) { return {s}; } +inline internal::UdlArg +operator"" _a(const wchar_t *s, std::size_t) { return {s}; } + +} // inline namespace literals +} // namespace fmt +#endif // FMT_USE_USER_DEFINED_LITERALS + +// Restore warnings. 
+#if FMT_GCC_VERSION >= 406 +# pragma GCC diagnostic pop +#endif + +#if defined(__clang__) && !defined(FMT_ICC_VERSION) +# pragma clang diagnostic pop +#endif + +#ifdef FMT_HEADER_ONLY +# define FMT_FUNC inline +# include "format.cc" +#else +# define FMT_FUNC +#endif + +#endif // FMT_FORMAT_H_ diff --git a/diy/include/diy/fmt/ostream.cc b/diy/include/diy/fmt/ostream.cc new file mode 100644 index 000000000..0ba303478 --- /dev/null +++ b/diy/include/diy/fmt/ostream.cc @@ -0,0 +1,61 @@ +/* + Formatting library for C++ - std::ostream support + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#include "ostream.h" + +namespace fmt { + +namespace { +// Write the content of w to os. 
+void write(std::ostream &os, Writer &w) { + const char *data = w.data(); + typedef internal::MakeUnsigned::Type UnsignedStreamSize; + UnsignedStreamSize size = w.size(); + UnsignedStreamSize max_size = + internal::to_unsigned((std::numeric_limits::max)()); + do { + UnsignedStreamSize n = size <= max_size ? size : max_size; + os.write(data, static_cast(n)); + data += n; + size -= n; + } while (size != 0); +} +} + +FMT_FUNC void print(std::ostream &os, CStringRef format_str, ArgList args) { + MemoryWriter w; + w.write(format_str, args); + write(os, w); +} + +FMT_FUNC int fprintf(std::ostream &os, CStringRef format, ArgList args) { + MemoryWriter w; + printf(w, format, args); + write(os, w); + return static_cast(w.size()); +} +} // namespace fmt diff --git a/diy/include/diy/fmt/ostream.h b/diy/include/diy/fmt/ostream.h new file mode 100644 index 000000000..812278dd3 --- /dev/null +++ b/diy/include/diy/fmt/ostream.h @@ -0,0 +1,133 @@ +/* + Formatting library for C++ - std::ostream support + + Copyright (c) 2012 - 2016, Victor Zverovich + All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + */ + +#ifndef FMT_OSTREAM_H_ +#define FMT_OSTREAM_H_ + +#include "format.h" +#include + +namespace fmt { + +namespace internal { + +template +class FormatBuf : public std::basic_streambuf { + private: + typedef typename std::basic_streambuf::int_type int_type; + typedef typename std::basic_streambuf::traits_type traits_type; + + Buffer &buffer_; + Char *start_; + + public: + FormatBuf(Buffer &buffer) : buffer_(buffer), start_(&buffer[0]) { + this->setp(start_, start_ + buffer_.capacity()); + } + + int_type overflow(int_type ch = traits_type::eof()) { + if (!traits_type::eq_int_type(ch, traits_type::eof())) { + size_t buf_size = size(); + buffer_.resize(buf_size); + buffer_.reserve(buf_size * 2); + + start_ = &buffer_[0]; + start_[buf_size] = traits_type::to_char_type(ch); + this->setp(start_+ buf_size + 1, start_ + buf_size * 2); + } + return ch; + } + + size_t size() const { + return to_unsigned(this->pptr() - start_); + } +}; + +Yes &convert(std::ostream &); + +struct DummyStream : std::ostream { + DummyStream(); // Suppress a bogus warning in MSVC. + // Hide all operator<< overloads from std::ostream. + void operator<<(Null<>); +}; + +No &operator<<(std::ostream &, int); + +template +struct ConvertToIntImpl { + // Convert to int only if T doesn't have an overloaded operator<<. + enum { + value = sizeof(convert(get() << get())) == sizeof(No) + }; +}; +} // namespace internal + +// Formats a value. 
+template +void format(BasicFormatter &f, + const Char *&format_str, const T &value) { + internal::MemoryBuffer buffer; + + internal::FormatBuf format_buf(buffer); + std::basic_ostream output(&format_buf); + output << value; + + BasicStringRef str(&buffer[0], format_buf.size()); + typedef internal::MakeArg< BasicFormatter > MakeArg; + format_str = f.format(format_str, MakeArg(str)); +} + +/** + \rst + Prints formatted data to the stream *os*. + + **Example**:: + + print(cerr, "Don't {}!", "panic"); + \endrst + */ +FMT_API void print(std::ostream &os, CStringRef format_str, ArgList args); +FMT_VARIADIC(void, print, std::ostream &, CStringRef) + +/** + \rst + Prints formatted data to the stream *os*. + + **Example**:: + + fprintf(cerr, "Don't %s!", "panic"); + \endrst + */ +FMT_API int fprintf(std::ostream &os, CStringRef format_str, ArgList args); +FMT_VARIADIC(int, fprintf, std::ostream &, CStringRef) +} // namespace fmt + +#ifdef FMT_HEADER_ONLY +# include "ostream.cc" +#endif + +#endif // FMT_OSTREAM_H_ diff --git a/diy/include/diy/grid.hpp b/diy/include/diy/grid.hpp new file mode 100644 index 000000000..cfdb72a65 --- /dev/null +++ b/diy/include/diy/grid.hpp @@ -0,0 +1,153 @@ +#ifndef DIY_GRID_HPP +#define DIY_GRID_HPP + +#include "point.hpp" + +namespace diy +{ + +template +struct Grid; + +template +struct GridRef +{ + public: + typedef C Value; + + typedef Point Vertex; + typedef size_t Index; + + public: + template + GridRef(C* data, const Point& shape, bool c_order = true): + data_(data), shape_(shape), c_order_(c_order) { set_stride(); } + + GridRef(Grid& g): + data_(g.data()), shape_(g.shape()), + c_order_(g.c_order()) { set_stride(); } + + template + C operator()(const Point& v) const { return (*this)(index(v)); } + + template + C& operator()(const Point& v) { return (*this)(index(v)); } + + C operator()(Index i) const { return data_[i]; } + C& operator()(Index i) { return data_[i]; } + + const Vertex& + shape() const { return shape_; } + + const C* + 
data() const { return data_; } + C* data() { return data_; } + + // Set every element to the given value + GridRef& operator=(C value) { Index s = size(); for (Index i = 0; i < s; ++i) data_[i] = value; return *this; } + GridRef& operator/=(C value) { Index s = size(); for (Index i = 0; i < s; ++i) data_[i] /= value; return *this; } + + Vertex vertex(Index idx) const { Vertex v; for (unsigned i = 0; i < D; ++i) { v[i] = idx / stride_[i]; idx %= stride_[i]; } return v; } + Index index(const Vertex& v) const { Index idx = 0; for (unsigned i = 0; i < D; ++i) { idx += ((Index) v[i]) * ((Index) stride_[i]); } return idx; } + + Index size() const { return size(shape()); } + void swap(GridRef& other) { std::swap(data_, other.data_); std::swap(shape_, other.shape_); std::swap(stride_, other.stride_); std::swap(c_order_, other.c_order_); } + + bool c_order() const { return c_order_; } + + static constexpr + unsigned dimension() { return D; } + + protected: + static Index + size(const Vertex& v) { Index res = 1; for (unsigned i = 0; i < D; ++i) res *= v[i]; return res; } + + void set_stride() + { + Index cur = 1; + if (c_order_) + for (unsigned i = D; i > 0; --i) { stride_[i-1] = cur; cur *= shape_[i-1]; } + else + for (unsigned i = 0; i < D; ++i) { stride_[i] = cur; cur *= shape_[i]; } + + } + void set_shape(const Vertex& v) { shape_ = v; set_stride(); } + void set_data(C* data) { data_ = data; } + void set_c_order(bool order) { c_order_ = order; } + + private: + C* data_; + Vertex shape_; + Vertex stride_; + bool c_order_; +}; + + +template +struct Grid: public GridRef +{ + public: + typedef GridRef Parent; + typedef typename Parent::Value Value; + typedef typename Parent::Index Index; + typedef typename Parent::Vertex Vertex; + typedef Parent Reference; + + template + struct rebind { typedef Grid type; }; + + public: + Grid(): + Parent(new C[0], Vertex::zero()) {} + template + Grid(const Point& shape, bool c_order = true): + Parent(new C[size(shape)], shape, c_order) + {} 
+ + Grid(Grid&& g): Grid() { Parent::swap(g); } + + Grid(const Parent& g): + Parent(new C[size(g.shape())], g.shape(), + g.c_order()) { copy_data(g.data()); } + + template + Grid(const OtherGrid& g): + Parent(new C[size(g.shape())], + g.shape(), + g.c_order()) { copy_data(g.data()); } + + ~Grid() { delete[] Parent::data(); } + + template + Grid& operator=(const GridRef& other) + { + delete[] Parent::data(); + Parent::set_c_order(other.c_order()); // NB: order needs to be set before the shape, to set the stride correctly + Parent::set_shape(other.shape()); + Index s = size(shape()); + Parent::set_data(new C[s]); + copy_data(other.data()); + return *this; + } + + Grid& operator=(Grid&& g) { Parent::swap(g); return *this; } + + using Parent::data; + using Parent::shape; + using Parent::operator(); + using Parent::operator=; + using Parent::size; + + private: + template + void copy_data(const OC* data) + { + Index s = size(shape()); + for (Index i = 0; i < s; ++i) + Parent::data()[i] = data[i]; + } +}; + +} + +#endif diff --git a/diy/include/diy/io/block.hpp b/diy/include/diy/io/block.hpp new file mode 100644 index 000000000..05e45a800 --- /dev/null +++ b/diy/include/diy/io/block.hpp @@ -0,0 +1,396 @@ +#ifndef DIY_IO_BLOCK_HPP +#define DIY_IO_BLOCK_HPP + +#include +#include +#include + +#include +#include +#include + +#include "../mpi.hpp" +#include "../assigner.hpp" +#include "../master.hpp" +#include "../storage.hpp" +#include "../log.hpp" + +// Read and write collections of blocks using MPI-IO +namespace diy +{ +namespace io +{ + namespace detail + { + typedef mpi::io::offset offset_t; + + struct GidOffsetCount + { + GidOffsetCount(): // need to initialize a vector of given size + gid(-1), offset(0), count(0) {} + + GidOffsetCount(int gid_, offset_t offset_, offset_t count_): + gid(gid_), offset(offset_), count(count_) {} + + bool operator<(const GidOffsetCount& other) const { return gid < other.gid; } + + int gid; + offset_t offset; + offset_t count; + }; + } +} + 
+// Serialize GidOffsetCount explicitly, to avoid alignment and unitialized data issues +// (to get identical output files given the same block input) +template<> +struct Serialization +{ + typedef io::detail::GidOffsetCount GidOffsetCount; + + static void save(BinaryBuffer& bb, const GidOffsetCount& x) + { + diy::save(bb, x.gid); + diy::save(bb, x.offset); + diy::save(bb, x.count); + } + + static void load(BinaryBuffer& bb, GidOffsetCount& x) + { + diy::load(bb, x.gid); + diy::load(bb, x.offset); + diy::load(bb, x.count); + } +}; + +namespace io +{ +/** + * \ingroup IO + * \brief Write blocks to storage collectively in one shared file + */ + inline + void + write_blocks(const std::string& outfilename, //!< output file name + const mpi::communicator& comm, //!< communicator + Master& master, //!< master object + const MemoryBuffer& extra = MemoryBuffer(),//!< user-defined metadata for file header; meaningful only on rank == 0 + Master::SaveBlock save = 0) //!< block save function in case different than or undefined in the master + { + if (!save) save = master.saver(); // save is likely to be different from master.save() + + typedef detail::offset_t offset_t; + typedef detail::GidOffsetCount GidOffsetCount; + + unsigned size = master.size(), + max_size, min_size; + mpi::all_reduce(comm, size, max_size, mpi::maximum()); + mpi::all_reduce(comm, size, min_size, mpi::minimum()); + + // truncate the file + if (comm.rank() == 0) + truncate(outfilename.c_str(), 0); + + mpi::io::file f(comm, outfilename, mpi::io::file::wronly | mpi::io::file::create); + + offset_t start = 0, shift; + std::vector offset_counts; + unsigned i; + for (i = 0; i < max_size; ++i) + { + offset_t count = 0, + offset; + if (i < size) + { + // get the block from master and serialize it + const void* block = master.get(i); + MemoryBuffer bb; + LinkFactory::save(bb, master.link(i)); + save(block, bb); + count = bb.buffer.size(); + mpi::scan(comm, count, offset, std::plus()); + offset += start - count; + 
mpi::all_reduce(comm, count, shift, std::plus()); + start += shift; + + if (i < min_size) // up to min_size, we can do collective IO + f.write_at_all(offset, bb.buffer); + else + f.write_at(offset, bb.buffer); + + offset_counts.push_back(GidOffsetCount(master.gid(i), offset, count)); + } else + { + // matching global operations + mpi::scan(comm, count, offset, std::plus()); + mpi::all_reduce(comm, count, shift, std::plus()); + + // -1 indicates that there is no block written here from this rank + offset_counts.push_back(GidOffsetCount(-1, offset, count)); + } + } + + if (comm.rank() == 0) + { + // round-about way of gather vector of vectors of GidOffsetCount to avoid registering a new mpi datatype + std::vector< std::vector > gathered_offset_count_buffers; + MemoryBuffer oc_buffer; diy::save(oc_buffer, offset_counts); + mpi::gather(comm, oc_buffer.buffer, gathered_offset_count_buffers, 0); + + std::vector all_offset_counts; + for (unsigned i = 0; i < gathered_offset_count_buffers.size(); ++i) + { + MemoryBuffer oc_buffer; oc_buffer.buffer.swap(gathered_offset_count_buffers[i]); + std::vector offset_counts; + diy::load(oc_buffer, offset_counts); + for (unsigned j = 0; j < offset_counts.size(); ++j) + if (offset_counts[j].gid != -1) + all_offset_counts.push_back(offset_counts[j]); + } + std::sort(all_offset_counts.begin(), all_offset_counts.end()); // sorts by gid + + MemoryBuffer bb; + diy::save(bb, all_offset_counts); + diy::save(bb, extra); + size_t footer_size = bb.size(); + diy::save(bb, footer_size); + + // find footer_offset as the max of (offset + count) + offset_t footer_offset = 0; + for (unsigned i = 0; i < all_offset_counts.size(); ++i) + { + offset_t end = all_offset_counts[i].offset + all_offset_counts[i].count; + if (end > footer_offset) + footer_offset = end; + } + f.write_at(footer_offset, bb.buffer); + } else + { + MemoryBuffer oc_buffer; diy::save(oc_buffer, offset_counts); + mpi::gather(comm, oc_buffer.buffer, 0); + } + } + +/** + * \ingroup IO + 
* \brief Read blocks from storage collectively from one shared file + */ + inline + void + read_blocks(const std::string& infilename, //!< input file name + const mpi::communicator& comm, //!< communicator + Assigner& assigner, //!< assigner object + Master& master, //!< master object + MemoryBuffer& extra, //!< user-defined metadata in file header + Master::LoadBlock load = 0) //!< load block function in case different than or unefined in the master + { + if (!load) load = master.loader(); // load is likely to be different from master.load() + + typedef detail::offset_t offset_t; + typedef detail::GidOffsetCount GidOffsetCount; + + mpi::io::file f(comm, infilename, mpi::io::file::rdonly); + + offset_t footer_offset = f.size() - sizeof(size_t); + size_t footer_size; + + // Read the size + f.read_at_all(footer_offset, (char*) &footer_size, sizeof(footer_size)); + + // Read all_offset_counts + footer_offset -= footer_size; + MemoryBuffer footer; + footer.buffer.resize(footer_size); + f.read_at_all(footer_offset, footer.buffer); + + std::vector all_offset_counts; + diy::load(footer, all_offset_counts); + diy::load(footer, extra); + extra.reset(); + + // Get local gids from assigner + size_t size = all_offset_counts.size(); + assigner.set_nblocks(size); + std::vector gids; + assigner.local_gids(comm.rank(), gids); + + for (unsigned i = 0; i < gids.size(); ++i) + { + if (gids[i] != all_offset_counts[gids[i]].gid) + get_logger()->warn("gids don't match in diy::io::read_blocks(), {} vs {}", + gids[i], all_offset_counts[gids[i]].gid); + + offset_t offset = all_offset_counts[gids[i]].offset, + count = all_offset_counts[gids[i]].count; + MemoryBuffer bb; + bb.buffer.resize(count); + f.read_at(offset, bb.buffer); + Link* l = LinkFactory::load(bb); + l->fix(assigner); + void* b = master.create(); + load(b, bb); + master.add(gids[i], b, l); + } + } + + + // Functions without the extra buffer, for compatibility with the old code + inline + void + write_blocks(const std::string& 
outfilename, + const mpi::communicator& comm, + Master& master, + Master::SaveBlock save) + { + MemoryBuffer extra; + write_blocks(outfilename, comm, master, extra, save); + } + + inline + void + read_blocks(const std::string& infilename, + const mpi::communicator& comm, + Assigner& assigner, + Master& master, + Master::LoadBlock load = 0) + { + MemoryBuffer extra; // dummy + read_blocks(infilename, comm, assigner, master, extra, load); + } + +namespace split +{ +/** + * \ingroup IO + * \brief Write blocks to storage independently in one file per process + */ + inline + void + write_blocks(const std::string& outfilename, //!< output file name + const mpi::communicator& comm, //!< communicator + Master& master, //!< master object + const MemoryBuffer& extra = MemoryBuffer(),//!< user-defined metadata for file header; meaningful only on rank == 0 + Master::SaveBlock save = 0) //!< block save function in case different than or undefined in master + { + if (!save) save = master.saver(); // save is likely to be different from master.save() + + bool proceed = false; + size_t size = 0; + if (comm.rank() == 0) + { + struct stat s; + if (stat(outfilename.c_str(), &s) == 0) + { + if (S_ISDIR(s.st_mode)) + proceed = true; + } else if (mkdir(outfilename.c_str(), 0755) == 0) + proceed = true; + mpi::broadcast(comm, proceed, 0); + mpi::reduce(comm, (size_t) master.size(), size, 0, std::plus()); + } else + { + mpi::broadcast(comm, proceed, 0); + mpi::reduce(comm, (size_t) master.size(), 0, std::plus()); + } + + if (!proceed) + throw std::runtime_error("Cannot access or create directory: " + outfilename); + + for (int i = 0; i < (int)master.size(); ++i) + { + const void* block = master.get(i); + + std::string filename = fmt::format("{}/{}", outfilename, master.gid(i)); + + ::diy::detail::FileBuffer bb(fopen(filename.c_str(), "w")); + + LinkFactory::save(bb, master.link(i)); + save(block, bb); + + fclose(bb.file); + } + + if (comm.rank() == 0) + { + // save the extra buffer + 
std::string filename = outfilename + "/extra"; + ::diy::detail::FileBuffer bb(fopen(filename.c_str(), "w")); + ::diy::save(bb, size); + ::diy::save(bb, extra); + fclose(bb.file); + } + } + +/** + * \ingroup IO + * \brief Read blocks from storage independently from one file per process + */ + inline + void + read_blocks(const std::string& infilename, //!< input file name + const mpi::communicator& comm, //!< communicator + Assigner& assigner, //!< assigner object + Master& master, //!< master object + MemoryBuffer& extra, //!< user-defined metadata in file header + Master::LoadBlock load = 0) //!< block load function in case different than or undefined in master + { + if (!load) load = master.loader(); // load is likely to be different from master.load() + + // load the extra buffer and size + size_t size; + std::string filename = infilename + "/extra"; + ::diy::detail::FileBuffer bb(fopen(filename.c_str(), "r")); + ::diy::load(bb, size); + ::diy::load(bb, extra); + extra.reset(); + fclose(bb.file); + + // Get local gids from assigner + assigner.set_nblocks(size); + std::vector gids; + assigner.local_gids(comm.rank(), gids); + + // Read our blocks; + for (unsigned i = 0; i < gids.size(); ++i) + { + std::string filename = fmt::format("{}/{}", infilename, gids[i]); + + ::diy::detail::FileBuffer bb(fopen(filename.c_str(), "r")); + Link* l = LinkFactory::load(bb); + l->fix(assigner); + void* b = master.create(); + load(b, bb); + master.add(gids[i], b, l); + + fclose(bb.file); + } + } + + // Functions without the extra buffer, for compatibility with the old code + inline + void + write_blocks(const std::string& outfilename, + const mpi::communicator& comm, + Master& master, + Master::SaveBlock save) + { + MemoryBuffer extra; + write_blocks(outfilename, comm, master, extra, save); + } + + inline + void + read_blocks(const std::string& infilename, + const mpi::communicator& comm, + Assigner& assigner, + Master& master, + Master::LoadBlock load = 0) + { + MemoryBuffer 
extra; // dummy + read_blocks(infilename, comm, assigner, master, extra, load); + } +} // split +} // io +} // diy + +#endif diff --git a/diy/include/diy/io/bov.hpp b/diy/include/diy/io/bov.hpp new file mode 100644 index 000000000..bd8b24009 --- /dev/null +++ b/diy/include/diy/io/bov.hpp @@ -0,0 +1,171 @@ +#ifndef DIY_IO_BOV_HPP +#define DIY_IO_BOV_HPP + +#include +#include +#include + +#include "../types.hpp" +#include "../mpi.hpp" + +namespace diy +{ +namespace io +{ + // Reads and writes subsets of a block of values into specified block bounds + class BOV + { + public: + typedef std::vector Shape; + public: + BOV(mpi::io::file& f): + f_(f), offset_(0) {} + + template + BOV(mpi::io::file& f, + const S& shape = S(), + mpi::io::offset offset = 0): + f_(f), offset_(offset) { set_shape(shape); } + + void set_offset(mpi::io::offset offset) { offset_ = offset; } + + template + void set_shape(const S& shape) + { + shape_.clear(); + stride_.clear(); + for (unsigned i = 0; i < shape.size(); ++i) + { + shape_.push_back(shape[i]); + stride_.push_back(1); + } + for (int i = shape_.size() - 2; i >= 0; --i) + stride_[i] = stride_[i+1] * shape_[i+1]; + } + + const Shape& shape() const { return shape_; } + + template + void read(const DiscreteBounds& bounds, T* buffer, bool collective = false, int chunk = 1) const; + + template + void write(const DiscreteBounds& bounds, const T* buffer, bool collective = false, int chunk = 1); + + template + void write(const DiscreteBounds& bounds, const T* buffer, const DiscreteBounds& core, bool collective = false, int chunk = 1); + + protected: + mpi::io::file& file() { return f_; } + + private: + mpi::io::file& f_; + Shape shape_; + std::vector stride_; + size_t offset_; + }; +} +} + +template +void +diy::io::BOV:: +read(const DiscreteBounds& bounds, T* buffer, bool collective, int chunk) const +{ + int dim = shape_.size(); + int total = 1; + std::vector subsizes; + for (int i = 0; i < dim; ++i) + { + subsizes.push_back(bounds.max[i] - 
bounds.min[i] + 1); + total *= subsizes.back(); + } + + MPI_Datatype T_type; + if (chunk == 1) + T_type = mpi::detail::get_mpi_datatype(); + else + { + // create an MPI struct of size chunk to read the data in those chunks + // (this allows to work around MPI-IO weirdness where crucial quantities + // are ints, which are too narrow of a type) + int array_of_blocklengths[] = { chunk }; + MPI_Aint array_of_displacements[] = { 0 }; + MPI_Datatype array_of_types[] = { mpi::detail::get_mpi_datatype() }; + MPI_Type_create_struct(1, array_of_blocklengths, array_of_displacements, array_of_types, &T_type); + MPI_Type_commit(&T_type); + } + + MPI_Datatype fileblk; + MPI_Type_create_subarray(dim, (int*) &shape_[0], &subsizes[0], (int*) &bounds.min[0], MPI_ORDER_C, T_type, &fileblk); + MPI_Type_commit(&fileblk); + + MPI_File_set_view(f_.handle(), offset_, T_type, fileblk, (char*)"native", MPI_INFO_NULL); + + mpi::status s; + if (!collective) + MPI_File_read(f_.handle(), buffer, total, T_type, &s.s); + else + MPI_File_read_all(f_.handle(), buffer, total, T_type, &s.s); + + if (chunk != 1) + MPI_Type_free(&T_type); + MPI_Type_free(&fileblk); +} + +template +void +diy::io::BOV:: +write(const DiscreteBounds& bounds, const T* buffer, bool collective, int chunk) +{ + write(bounds, buffer, bounds, collective, chunk); +} + +template +void +diy::io::BOV:: +write(const DiscreteBounds& bounds, const T* buffer, const DiscreteBounds& core, bool collective, int chunk) +{ + int dim = shape_.size(); + std::vector subsizes; + std::vector buffer_shape, buffer_start; + for (int i = 0; i < dim; ++i) + { + buffer_shape.push_back(bounds.max[i] - bounds.min[i] + 1); + buffer_start.push_back(core.min[i] - bounds.min[i]); + subsizes.push_back(core.max[i] - core.min[i] + 1); + } + + MPI_Datatype T_type; + if (chunk == 1) + T_type = mpi::detail::get_mpi_datatype(); + else + { + // assume T is a binary block and create an MPI struct of appropriate size + int array_of_blocklengths[] = { chunk }; + 
MPI_Aint array_of_displacements[] = { 0 }; + MPI_Datatype array_of_types[] = { mpi::detail::get_mpi_datatype() }; + MPI_Type_create_struct(1, array_of_blocklengths, array_of_displacements, array_of_types, &T_type); + MPI_Type_commit(&T_type); + } + + MPI_Datatype fileblk, subbuffer; + MPI_Type_create_subarray(dim, (int*) &shape_[0], &subsizes[0], (int*) &bounds.min[0], MPI_ORDER_C, T_type, &fileblk); + MPI_Type_create_subarray(dim, (int*) &buffer_shape[0], &subsizes[0], (int*) &buffer_start[0], MPI_ORDER_C, T_type, &subbuffer); + MPI_Type_commit(&fileblk); + MPI_Type_commit(&subbuffer); + + MPI_File_set_view(f_.handle(), offset_, T_type, fileblk, (char*)"native", MPI_INFO_NULL); + + mpi::status s; + if (!collective) + MPI_File_write(f_.handle(), (void*)buffer, 1, subbuffer, &s.s); + else + MPI_File_write_all(f_.handle(), (void*)buffer, 1, subbuffer, &s.s); + + if (chunk != 1) + MPI_Type_free(&T_type); + MPI_Type_free(&fileblk); + MPI_Type_free(&subbuffer); +} + +#endif diff --git a/diy/include/diy/io/numpy.hpp b/diy/include/diy/io/numpy.hpp new file mode 100644 index 000000000..0199a0c38 --- /dev/null +++ b/diy/include/diy/io/numpy.hpp @@ -0,0 +1,213 @@ +#ifndef DIY_IO_NMPY_HPP +#define DIY_IO_NMPY_HPP + +#include +#include +#include + +#include "../serialization.hpp" +#include "bov.hpp" + +namespace diy +{ +namespace io +{ + class NumPy: public BOV + { + public: + NumPy(mpi::io::file& f): + BOV(f) {} + + unsigned word_size() const { return word_size_; } + + unsigned read_header() + { + BOV::Shape shape; + bool fortran; + size_t offset = parse_npy_header(shape, fortran); + if (fortran) + throw std::runtime_error("diy::io::NumPy cannot read data in fortran order"); + BOV::set_offset(offset); + BOV::set_shape(shape); + return word_size_; + } + + template + void write_header(int dim, const DiscreteBounds& bounds); + + template + void write_header(const S& shape); + + private: + inline size_t parse_npy_header(BOV::Shape& shape, bool& fortran_order); + void 
save(diy::BinaryBuffer& bb, const std::string& s) { bb.save_binary(s.c_str(), s.size()); } + template + inline void convert_and_save(diy::BinaryBuffer& bb, const T& x) + { + std::ostringstream oss; + oss << x; + save(bb, oss.str()); + } + + private: + unsigned word_size_; + }; + + namespace detail + { + inline char big_endian(); + template + char map_numpy_type(); + } +} +} + +// Modified from: https://github.com/rogersce/cnpy +// Copyright (C) 2011 Carl Rogers +// Released under MIT License +// license available at http://www.opensource.org/licenses/mit-license.php +size_t +diy::io::NumPy:: +parse_npy_header(BOV::Shape& shape, bool& fortran_order) +{ + char buffer[256]; + file().read_at_all(0, buffer, 256); + std::string header(buffer, buffer + 256); + size_t nl = header.find('\n'); + if (nl == std::string::npos) + throw std::runtime_error("parse_npy_header: failed to read the header"); + header = header.substr(11, nl - 11 + 1); + size_t header_size = nl + 1; + + int loc1, loc2; + + //fortran order + loc1 = header.find("fortran_order")+16; + fortran_order = (header.substr(loc1,4) == "True" ? true : false); + + //shape + unsigned ndims; + loc1 = header.find("("); + loc2 = header.find(")"); + std::string str_shape = header.substr(loc1+1,loc2-loc1-1); + if(str_shape[str_shape.size()-1] == ',') ndims = 1; + else ndims = std::count(str_shape.begin(),str_shape.end(),',')+1; + shape.resize(ndims); + for(unsigned int i = 0;i < ndims;i++) { + loc1 = str_shape.find(","); + shape[i] = atoi(str_shape.substr(0,loc1).c_str()); + str_shape = str_shape.substr(loc1+1); + } + + //endian, word size, data type + //byte order code | stands for not applicable. + //not sure when this applies except for byte array + loc1 = header.find("descr")+9; + //bool littleEndian = (header[loc1] == '<' || header[loc1] == '|' ? 
true : false); + //assert(littleEndian); + + //char type = header[loc1+1]; + //assert(type == map_type(T)); + + std::string str_ws = header.substr(loc1+2); + loc2 = str_ws.find("'"); + word_size_ = atoi(str_ws.substr(0,loc2).c_str()); + + return header_size; +} + +template +void +diy::io::NumPy:: +write_header(int dim, const DiscreteBounds& bounds) +{ + std::vector shape; + for (int i = 0; i < dim; ++i) + shape.push_back(bounds.max[i] - bounds.min[i] + 1); + + write_header< T, std::vector >(shape); +} + + +template +void +diy::io::NumPy:: +write_header(const S& shape) +{ + BOV::set_shape(shape); + + diy::MemoryBuffer dict; + save(dict, "{'descr': '"); + diy::save(dict, detail::big_endian()); + diy::save(dict, detail::map_numpy_type()); + convert_and_save(dict, sizeof(T)); + save(dict, "', 'fortran_order': False, 'shape': ("); + convert_and_save(dict, shape[0]); + for (int i = 1; i < (int) shape.size(); i++) + { + save(dict, ", "); + convert_and_save(dict, shape[i]); + } + if(shape.size() == 1) save(dict, ","); + save(dict, "), }"); + //pad with spaces so that preamble+dict is modulo 16 bytes. preamble is 10 bytes. dict needs to end with \n + int remainder = 16 - (10 + dict.position) % 16; + for (int i = 0; i < remainder - 1; ++i) + diy::save(dict, ' '); + diy::save(dict, '\n'); + + diy::MemoryBuffer header; + diy::save(header, (char) 0x93); + save(header, "NUMPY"); + diy::save(header, (char) 0x01); // major version of numpy format + diy::save(header, (char) 0x00); // minor version of numpy format + diy::save(header, (unsigned short) dict.position); + header.save_binary(&dict.buffer[0], dict.buffer.size()); + + BOV::set_offset(header.position); + + if (file().comm().rank() == 0) + file().write_at(0, &header.buffer[0], header.buffer.size()); +} + +char +diy::io::detail::big_endian() +{ + unsigned char x[] = {1,0}; + void* x_void = x; + short y = *static_cast(x_void); + return y == 1 ? 
'<' : '>';
+}
+
+namespace diy
+{
+namespace io
+{
+namespace detail
+{
+template<> inline char map_numpy_type<float>()                  { return 'f'; }
+template<> inline char map_numpy_type<double>()                 { return 'f'; }
+template<> inline char map_numpy_type<long double>()            { return 'f'; }
+
+template<> inline char map_numpy_type<char>()                   { return 'i'; }
+template<> inline char map_numpy_type<short>()                  { return 'i'; }
+template<> inline char map_numpy_type<int>()                    { return 'i'; }
+template<> inline char map_numpy_type<long>()                   { return 'i'; }
+template<> inline char map_numpy_type<long long>()              { return 'i'; }
+
+template<> inline char map_numpy_type<unsigned char>()          { return 'u'; }
+template<> inline char map_numpy_type<unsigned short>()         { return 'u'; }
+template<> inline char map_numpy_type<unsigned int>()           { return 'u'; }
+template<> inline char map_numpy_type<unsigned long>()          { return 'u'; }
+template<> inline char map_numpy_type<unsigned long long>()     { return 'u'; }
+
+template<> inline char map_numpy_type<bool>()                   { return 'b'; }
+
+template<> inline char map_numpy_type< std::complex<float> >()       { return 'c'; }
+template<> inline char map_numpy_type< std::complex<double> >()      { return 'c'; }
+template<> inline char map_numpy_type< std::complex<long double> >() { return 'c'; }
+}
+}
+}
+
+#endif
diff --git a/diy/include/diy/link.hpp b/diy/include/diy/link.hpp
new file mode 100644
index 000000000..3262eef61
--- /dev/null
+++ b/diy/include/diy/link.hpp
@@ -0,0 +1,219 @@
+#ifndef DIY_COVER_HPP
+#define DIY_COVER_HPP
+
+#include <vector>
+#include <map>
+#include <algorithm>
+
+#include "types.hpp"
+#include "serialization.hpp"
+#include "assigner.hpp"
+
+namespace diy
+{
+  // Local view of a distributed representation of a cover, a completely unstructured link
+  class Link
+  {
+    public:
+      virtual   ~Link()                             {}      // need to be able to delete derived classes
+      int       size() const                        { return neighbors_.size(); }
+      inline
+      int       size_unique() const;
+      BlockID   target(int i) const                 { return neighbors_[i]; }
+      BlockID&  target(int i)                       { return neighbors_[i]; }
+      inline
+      int       find(int gid) const;
+
+      void      add_neighbor(const BlockID& block)  { neighbors_.push_back(block); }
+
+      void      fix(const Assigner& assigner)
{ for (unsigned i = 0; i < neighbors_.size(); ++i) { neighbors_[i].proc = assigner.rank(neighbors_[i].gid); } } + + void swap(Link& other) { neighbors_.swap(other.neighbors_); } + + virtual void save(BinaryBuffer& bb) const { diy::save(bb, neighbors_); } + virtual void load(BinaryBuffer& bb) { diy::load(bb, neighbors_); } + + virtual size_t id() const { return 0; } + + private: + std::vector neighbors_; + }; + + template + class RegularLink; + + typedef RegularLink RegularGridLink; + typedef RegularLink RegularContinuousLink; + + // Selector between regular discrete and contious links given bounds type + template + struct RegularLinkSelector; + + template<> + struct RegularLinkSelector + { + typedef RegularGridLink type; + static const size_t id = 1; + }; + + template<> + struct RegularLinkSelector + { + typedef RegularContinuousLink type; + static const size_t id = 2; + }; + + + // for a regular decomposition, it makes sense to address the neighbors by direction + // and store local and neighbor bounds + template + class RegularLink: public Link + { + public: + typedef Bounds_ Bounds; + + typedef std::map DirMap; + typedef std::vector DirVec; + + public: + RegularLink(int dim, const Bounds& core, const Bounds& bounds): + dim_(dim), core_(core), bounds_(bounds) {} + + // dimension + int dimension() const { return dim_; } + + // direction + int direction(Direction dir) const; // convert direction to a neighbor (-1 if no neighbor) + Direction direction(int i) const { return dir_vec_[i]; } + void add_direction(Direction dir) { int c = dir_map_.size(); dir_map_[dir] = c; dir_vec_.push_back(dir); } + + // wrap + void add_wrap(Direction dir) { wrap_.push_back(dir); } + Direction wrap(int i) const { return wrap_[i]; } + Direction& wrap(int i) { return wrap_[i]; } + + // bounds + const Bounds& core() const { return core_; } + Bounds& core() { return core_; } + const Bounds& bounds() const { return bounds_; } + Bounds& bounds() { return bounds_; } + const Bounds& bounds(int 
i) const { return nbr_bounds_[i]; } + void add_bounds(const Bounds& bounds) { nbr_bounds_.push_back(bounds); } + + void swap(RegularLink& other) { Link::swap(other); dir_map_.swap(other.dir_map_); dir_vec_.swap(other.dir_vec_); nbr_bounds_.swap(other.nbr_bounds_); std::swap(dim_, other.dim_); wrap_.swap(other.wrap_); std::swap(core_, other.core_); std::swap(bounds_, other.bounds_); } + + void save(BinaryBuffer& bb) const + { + Link::save(bb); + diy::save(bb, dim_); + diy::save(bb, dir_map_); + diy::save(bb, dir_vec_); + diy::save(bb, core_); + diy::save(bb, bounds_); + diy::save(bb, nbr_bounds_); + diy::save(bb, wrap_); + } + + void load(BinaryBuffer& bb) + { + Link::load(bb); + diy::load(bb, dim_); + diy::load(bb, dir_map_); + diy::load(bb, dir_vec_); + diy::load(bb, core_); + diy::load(bb, bounds_); + diy::load(bb, nbr_bounds_); + diy::load(bb, wrap_); + } + + virtual size_t id() const { return RegularLinkSelector::id; } + + private: + int dim_; + + DirMap dir_map_; + DirVec dir_vec_; + + Bounds core_; + Bounds bounds_; + std::vector nbr_bounds_; + std::vector wrap_; + }; + + // Other cover candidates: KDTreeLink, AMRGridLink + + struct LinkFactory + { + public: + static Link* create(size_t id) + { + // not pretty, but will do for now + if (id == 0) + return new Link; + else if (id == 1) + return new RegularGridLink(0, DiscreteBounds(), DiscreteBounds()); + else if (id == 2) + return new RegularContinuousLink(0, ContinuousBounds(), ContinuousBounds()); + else + return 0; + } + + inline static void save(BinaryBuffer& bb, const Link* l); + inline static Link* load(BinaryBuffer& bb); + }; +} + + +void +diy::LinkFactory:: +save(BinaryBuffer& bb, const Link* l) +{ + diy::save(bb, l->id()); + l->save(bb); +} + +diy::Link* +diy::LinkFactory:: +load(BinaryBuffer& bb) +{ + size_t id; + diy::load(bb, id); + Link* l = create(id); + l->load(bb); + return l; +} + +int +diy::Link:: +find(int gid) const +{ + for (unsigned i = 0; i < (unsigned)size(); ++i) + { + if 
(target(i).gid == gid)
+      return i;
+  }
+  return -1;
+}
+int
+diy::Link::
+size_unique() const
+{
+    std::vector<BlockID> tmp(neighbors_.begin(), neighbors_.end());
+    std::sort(tmp.begin(), tmp.end());
+    return std::unique(tmp.begin(), tmp.end()) - tmp.begin();
+}
+
+template<class Bounds_>
+int
+diy::RegularLink<Bounds_>::
+direction(Direction dir) const
+{
+  DirMap::const_iterator it = dir_map_.find(dir);
+  if (it == dir_map_.end())
+    return -1;
+  else
+    return it->second;
+}
+
+#endif
diff --git a/diy/include/diy/log.hpp b/diy/include/diy/log.hpp
new file mode 100644
index 000000000..45f202f92
--- /dev/null
+++ b/diy/include/diy/log.hpp
@@ -0,0 +1,103 @@
+#ifndef DIY_LOG_HPP
+#define DIY_LOG_HPP
+
+#ifndef DIY_USE_SPDLOG
+
+#include <memory>
+#include "fmt/format.h"
+#include "fmt/ostream.h"
+
+namespace diy
+{
+
+namespace spd
+{
+    struct logger
+    {
+        // logger.info(cppformat_string, arg1, arg2, arg3, ...) call style
+        template<class... Args> void    trace(const char* fmt, const Args&... args)      {}
+        template<class... Args> void    debug(const char* fmt, const Args&... args)      {}
+        template<class... Args> void    info(const char* fmt, const Args&... args)       {}
+        template<class... Args> void    warn(const char* fmt, const Args&... args)       {}
+        template<class... Args> void    error(const char* fmt, const Args&... args)      {}
+        template<class... Args> void    critical(const char* fmt, const Args&... args)   {}
+    };
+}
+
+inline
+std::shared_ptr<spd::logger>
+get_logger()
+{
+    return std::make_shared<spd::logger>();
+}
+
+inline
+std::shared_ptr<spd::logger>
+create_logger(std::string)
+{
+    return std::make_shared<spd::logger>();
+}
+
+template<class... Args>
+std::shared_ptr<spd::logger>
+set_logger(Args...
args) +{ + return std::make_shared(); +} + +} // diy + +#else // DIY_USE_SPDLOG + +#include + +#include +#include + +#include +#include + +namespace diy +{ + +namespace spd = ::spdlog; + +inline +std::shared_ptr +get_logger() +{ + auto log = spd::get("diy"); + if (!log) + { + auto null_sink = std::make_shared (); + log = std::make_shared("null_logger", null_sink); + } + return log; +} + +inline +std::shared_ptr +create_logger(std::string log_level) +{ + auto log = spd::stderr_logger_mt("diy"); + int lvl; + for (lvl = spd::level::trace; lvl < spd::level::off; ++lvl) + if (spd::level::level_names[lvl] == log_level) + break; + log->set_level(static_cast(lvl)); + return log; +} + +template +std::shared_ptr +set_logger(Args... args) +{ + auto log = std::make_shared("diy", args...); + return log; +} + +} // diy +#endif + + +#endif // DIY_LOG_HPP diff --git a/diy/include/diy/master.hpp b/diy/include/diy/master.hpp new file mode 100644 index 000000000..ec7319a60 --- /dev/null +++ b/diy/include/diy/master.hpp @@ -0,0 +1,1205 @@ +#ifndef DIY_MASTER_HPP +#define DIY_MASTER_HPP + +#include +#include +#include +#include +#include +#include + +#include "link.hpp" +#include "collection.hpp" + +// Communicator functionality +#include "mpi.hpp" +#include "serialization.hpp" +#include "detail/collectives.hpp" +#include "time.hpp" + +#include "thread.hpp" + +#include "detail/block_traits.hpp" + +#include "log.hpp" +#include "stats.hpp" + +namespace diy +{ + // Stores and manages blocks; initiates serialization and communication when necessary. + // + // Provides a foreach function, which is meant as the main entry point. + // + // Provides a conversion between global and local block ids, + // which is hidden from blocks via a communicator proxy. 
+ class Master + { + public: + struct ProcessBlock; + + template + struct Binder; + + // Commands + struct BaseCommand; + + template + struct Command; + + typedef std::vector Commands; + + // Skip + using Skip = std::function; + + struct SkipNoIncoming; + struct NeverSkip { bool operator()(int i, const Master& master) const { return false; } }; + + // Collection + typedef Collection::Create CreateBlock; + typedef Collection::Destroy DestroyBlock; + typedef Collection::Save SaveBlock; + typedef Collection::Load LoadBlock; + + public: + // Communicator types + struct Proxy; + struct ProxyWithLink; + + // foreach callback + template + using Callback = std::function; + + struct QueuePolicy + { + virtual bool unload_incoming(const Master& master, int from, int to, size_t size) const =0; + virtual bool unload_outgoing(const Master& master, int from, size_t size) const =0; + virtual ~QueuePolicy() {} + }; + + //! Move queues out of core if their size exceeds a parameter given in the constructor + struct QueueSizePolicy: public QueuePolicy + { + QueueSizePolicy(size_t sz): size(sz) {} + bool unload_incoming(const Master& master, int from, int to, size_t sz) const { return sz > size; } + bool unload_outgoing(const Master& master, int from, size_t sz) const { return sz > size*master.outgoing_count(from); } + + size_t size; + }; + + struct MessageInfo + { + int from, to; + int round; + }; + + struct InFlightSend + { + std::shared_ptr message; + mpi::request request; + + // for debug purposes: + MessageInfo info; + }; + + struct InFlightRecv + { + MemoryBuffer message; + MessageInfo info{ -1, -1, -1 }; + }; + + struct Collective; + struct tags { enum { queue, piece }; }; + + typedef std::list InFlightSendsList; + typedef std::map InFlightRecvsMap; + typedef std::list ToSendList; // [gid] + typedef std::list CollectivesList; + typedef std::map CollectivesMap; // gid -> [collectives] + + + struct QueueRecord + { + QueueRecord(size_t s = 0, int e = -1): size(s), external(e) {} + 
size_t size; + int external; + }; + + typedef std::map InQueueRecords; // gid -> (size, external) + typedef std::map IncomingQueues; // gid -> queue + typedef std::map OutgoingQueues; // (gid, proc) -> queue + typedef std::map OutQueueRecords; // (gid, proc) -> (size, external) + struct IncomingQueuesRecords + { + InQueueRecords records; + IncomingQueues queues; + }; + struct OutgoingQueuesRecord + { + OutgoingQueuesRecord(int e = -1): external(e) {} + int external; + OutQueueRecords external_local; + OutgoingQueues queues; + }; + typedef std::map IncomingQueuesMap; // gid -> { gid -> queue } + typedef std::map OutgoingQueuesMap; // gid -> { (gid,proc) -> queue } + + struct IncomingRound + { + IncomingQueuesMap map; + int received{0}; + }; + typedef std::map IncomingRoundMap; + + + public: + /** + * \ingroup Initialization + * \brief The main DIY object + * + * Helper functions specify how to: + * create an empty block, + * destroy a block (a function that's expected to upcast and delete), + * serialize a block + */ + Master(mpi::communicator comm, //!< communicator + int threads = 1, //!< number of threads DIY can use + int limit = -1, //!< number of blocks to store in memory + CreateBlock create = 0, //!< block create function; master manages creation if create != 0 + DestroyBlock destroy = 0, //!< block destroy function; master manages destruction if destroy != 0 + ExternalStorage* storage = 0, //!< storage object (path, method, etc.) for storing temporary blocks being shuffled in/out of core + SaveBlock save = 0, //!< block save function; master manages saving if save != 0 + LoadBlock load = 0, //!< block load function; master manages loading if load != 0 + QueuePolicy* q_policy = new QueueSizePolicy(4096)): //!< policy for managing message queues specifies maximum size of message queues to keep in memory + blocks_(create, destroy, storage, save, load), + queue_policy_(q_policy), + limit_(limit), + threads_(threads == -1 ? 
thread::hardware_concurrency() : threads), + storage_(storage), + // Communicator functionality + comm_(comm), + expected_(0), + exchange_round_(-1), + immediate_(true) + {} + ~Master() { set_immediate(true); clear(); delete queue_policy_; } + inline void clear(); + inline void destroy(int i) { if (blocks_.own()) blocks_.destroy(i); } + + inline int add(int gid, void* b, Link* l); //!< add a block + inline void* release(int i); //!< release ownership of the block + + //!< return the `i`-th block + inline void* block(int i) const { return blocks_.find(i); } + template + Block* block(int i) const { return static_cast(block(i)); } + inline Link* link(int i) const { return links_[i]; } + inline int loaded_block() const { return blocks_.available(); } + + inline void unload(int i); + inline void load(int i); + void unload(std::vector& loaded) { for(unsigned i = 0; i < loaded.size(); ++i) unload(loaded[i]); loaded.clear(); } + void unload_all() { for(unsigned i = 0; i < size(); ++i) if (block(i) != 0) unload(i); } + inline bool has_incoming(int i) const; + + inline void unload_queues(int i); + inline void unload_incoming(int gid); + inline void unload_outgoing(int gid); + inline void load_queues(int i); + inline void load_incoming(int gid); + inline void load_outgoing(int gid); + + //! return the MPI communicator + const mpi::communicator& communicator() const { return comm_; } + //! return the MPI communicator + mpi::communicator& communicator() { return comm_; } + + //! return the `i`-th block, loading it if necessary + void* get(int i) { return blocks_.get(i); } + //! return gid of the `i`-th block + int gid(int i) const { return gids_[i]; } + //! return the local id of the local block with global id gid, or -1 if not local + int lid(int gid) const { return local(gid) ? lids_.find(gid)->second : -1; } + //! whether the block with global id gid is local + bool local(int gid) const { return lids_.find(gid) != lids_.end(); } + + //! 
exchange the queues between all the blocks (collective operation) + inline void exchange(); + inline void process_collectives(); + + inline + ProxyWithLink proxy(int i) const; + + //! return the number of local blocks + unsigned size() const { return blocks_.size(); } + void* create() const { return blocks_.create(); } + + // accessors + int limit() const { return limit_; } + int threads() const { return threads_; } + int in_memory() const { return *blocks_.in_memory().const_access(); } + + void set_threads(int threads) { threads_ = threads; } + + CreateBlock creator() const { return blocks_.creator(); } + DestroyBlock destroyer() const { return blocks_.destroyer(); } + LoadBlock loader() const { return blocks_.loader(); } + SaveBlock saver() const { return blocks_.saver(); } + + //! call `f` with every block + template + void foreach_(const Callback& f, const Skip& s = NeverSkip()); + + template + void foreach(const F& f, const Skip& s = NeverSkip()) + { + using Block = typename detail::block_traits::type; + foreach_(f, s); + } + + inline void execute(); + + bool immediate() const { return immediate_; } + void set_immediate(bool i) { if (i && !immediate_) execute(); immediate_ = i; } + + public: + // Communicator functionality + IncomingQueues& incoming(int gid) { return incoming_[exchange_round_].map[gid].queues; } + OutgoingQueues& outgoing(int gid) { return outgoing_[gid].queues; } + CollectivesList& collectives(int gid) { return collectives_[gid]; } + size_t incoming_count(int gid) const + { + IncomingRoundMap::const_iterator round_it = incoming_.find(exchange_round_); + if (round_it == incoming_.end()) + return 0; + IncomingQueuesMap::const_iterator queue_it = round_it->second.map.find(gid); + if (queue_it == round_it->second.map.end()) + return 0; + return queue_it->second.queues.size(); + } + size_t outgoing_count(int gid) const { OutgoingQueuesMap::const_iterator it = outgoing_.find(gid); if (it == outgoing_.end()) return 0; return 
it->second.queues.size(); } + + void set_expected(int expected) { expected_ = expected; } + void add_expected(int i) { expected_ += i; } + int expected() const { return expected_; } + void replace_link(int i, Link* link) { expected_ -= links_[i]->size_unique(); delete links_[i]; links_[i] = link; expected_ += links_[i]->size_unique(); } + + public: + // Communicator functionality + inline void flush(); // makes sure all the serialized queues migrate to their target processors + + private: + // Communicator functionality + inline void comm_exchange(ToSendList& to_send, int out_queues_limit); // possibly called in between block computations + inline bool nudge(); + + void cancel_requests(); // TODO + + // debug + inline void show_incoming_records() const; + + private: + std::vector links_; + Collection blocks_; + std::vector gids_; + std::map lids_; + + QueuePolicy* queue_policy_; + + int limit_; + int threads_; + ExternalStorage* storage_; + + private: + // Communicator + mpi::communicator comm_; + IncomingRoundMap incoming_; + OutgoingQueuesMap outgoing_; + InFlightSendsList inflight_sends_; + InFlightRecvsMap inflight_recvs_; + CollectivesMap collectives_; + int expected_; + int exchange_round_; + bool immediate_; + Commands commands_; + + private: + fast_mutex add_mutex_; + + public: + std::shared_ptr log = get_logger(); + stats::Profiler prof; + }; + + struct Master::BaseCommand + { + virtual ~BaseCommand() {} // to delete derived classes + virtual void execute(void* b, const ProxyWithLink& cp) const =0; + virtual bool skip(int i, const Master& master) const =0; + }; + + template + struct Master::Command: public BaseCommand + { + Command(Callback f_, const Skip& s_): + f(f_), s(s_) {} + + void execute(void* b, const ProxyWithLink& cp) const override { f(static_cast(b), cp); } + bool skip(int i, const Master& m) const override { return s(i,m); } + + Callback f; + Skip s; + }; + + struct Master::SkipNoIncoming + { bool operator()(int i, const Master& master) const 
{ return !master.has_incoming(i); } }; + + struct Master::Collective + { + Collective(): + cop_(0) {} + Collective(detail::CollectiveOp* cop): + cop_(cop) {} + // this copy constructor is very ugly, but need it to insert Collectives into a list + Collective(const Collective& other): + cop_(0) { swap(const_cast(other)); } + ~Collective() { delete cop_; } + + void init() { cop_->init(); } + void swap(Collective& other) { std::swap(cop_, other.cop_); } + void update(const Collective& other) { cop_->update(*other.cop_); } + void global(const mpi::communicator& c) { cop_->global(c); } + void copy_from(Collective& other) const { cop_->copy_from(*other.cop_); } + void result_out(void* x) const { cop_->result_out(x); } + + detail::CollectiveOp* cop_; + + private: + Collective& operator=(const Collective& other); + }; +} + +#include "proxy.hpp" + +// --- ProcessBlock --- +struct diy::Master::ProcessBlock +{ + ProcessBlock(Master& master_, + const std::deque& blocks_, + int local_limit_, + critical_resource& idx_): + master(master_), + blocks(blocks_), + local_limit(local_limit_), + idx(idx_) + {} + + void process() + { + master.log->debug("Processing with thread: {}", this_thread::get_id()); + + std::vector local; + do + { + int cur = (*idx.access())++; + + if ((size_t)cur >= blocks.size()) + return; + + int i = blocks[cur]; + if (master.block(i)) + { + if (local.size() == (size_t)local_limit) + master.unload(local); + local.push_back(i); + } + + master.log->debug("Processing block: {}", master.gid(i)); + + bool skip_block = true; + for (size_t cmd = 0; cmd < master.commands_.size(); ++cmd) + { + if (!master.commands_[cmd]->skip(i, master)) + { + skip_block = false; + break; + } + } + + IncomingQueuesMap ¤t_incoming = master.incoming_[master.exchange_round_].map; + if (skip_block) + { + if (master.block(i) == 0) + master.load_queues(i); // even though we are skipping the block, the queues might be necessary + + for (size_t cmd = 0; cmd < master.commands_.size(); ++cmd) + { 
+ master.commands_[cmd]->execute(0, master.proxy(i)); // 0 signals that we are skipping the block (even if it's loaded) + + // no longer need them, so get rid of them, rather than risk reloading + current_incoming[master.gid(i)].queues.clear(); + current_incoming[master.gid(i)].records.clear(); + } + + if (master.block(i) == 0) + master.unload_queues(i); // even though we are skipping the block, the queues might be necessary + } + else + { + if (master.block(i) == 0) // block unloaded + { + if (local.size() == (size_t)local_limit) // reached the local limit + master.unload(local); + + master.load(i); + local.push_back(i); + } + + for (size_t cmd = 0; cmd < master.commands_.size(); ++cmd) + { + master.commands_[cmd]->execute(master.block(i), master.proxy(i)); + + // no longer need them, so get rid of them + current_incoming[master.gid(i)].queues.clear(); + current_incoming[master.gid(i)].records.clear(); + } + } + } while(true); + + // TODO: invoke opportunistic communication + // don't forget to adjust Master::exchange() + } + + static void run(void* bf) { static_cast(bf)->process(); } + + Master& master; + const std::deque& blocks; + int local_limit; + critical_resource& idx; +}; +// -------------------- + +void +diy::Master:: +clear() +{ + for (unsigned i = 0; i < size(); ++i) + delete links_[i]; + blocks_.clear(); + links_.clear(); + gids_.clear(); + lids_.clear(); + expected_ = 0; +} + +void +diy::Master:: +unload(int i) +{ + log->debug("Unloading block: {}", gid(i)); + + blocks_.unload(i); + unload_queues(i); +} + +void +diy::Master:: +unload_queues(int i) +{ + unload_incoming(gid(i)); + unload_outgoing(gid(i)); +} + +void +diy::Master:: +unload_incoming(int gid) +{ + for (IncomingRoundMap::iterator round_itr = incoming_.begin(); round_itr != incoming_.end(); ++round_itr) + { + IncomingQueuesMap::iterator qmap_itr = round_itr->second.map.find(gid); + if (qmap_itr == round_itr->second.map.end()) + { + continue; + } + IncomingQueuesRecords& in_qrs = 
qmap_itr->second; + for (InQueueRecords::iterator it = in_qrs.records.begin(); it != in_qrs.records.end(); ++it) + { + QueueRecord& qr = it->second; + if (queue_policy_->unload_incoming(*this, it->first, gid, qr.size)) + { + log->debug("Unloading queue: {} <- {}", gid, it->first); + qr.external = storage_->put(in_qrs.queues[it->first]); + } + } + } +} + +void +diy::Master:: +unload_outgoing(int gid) +{ + OutgoingQueuesRecord& out_qr = outgoing_[gid]; + + size_t out_queues_size = sizeof(size_t); // map size + size_t count = 0; + for (OutgoingQueues::iterator it = out_qr.queues.begin(); it != out_qr.queues.end(); ++it) + { + if (it->first.proc == comm_.rank()) continue; + + out_queues_size += sizeof(BlockID); // target + out_queues_size += sizeof(size_t); // buffer.position + out_queues_size += sizeof(size_t); // buffer.size + out_queues_size += it->second.size(); // buffer contents + ++count; + } + if (queue_policy_->unload_outgoing(*this, gid, out_queues_size - sizeof(size_t))) + { + log->debug("Unloading outgoing queues: {} -> ...; size = {}\n", gid, out_queues_size); + MemoryBuffer bb; bb.reserve(out_queues_size); + diy::save(bb, count); + + for (OutgoingQueues::iterator it = out_qr.queues.begin(); it != out_qr.queues.end();) + { + if (it->first.proc == comm_.rank()) + { + // treat as incoming + if (queue_policy_->unload_incoming(*this, gid, it->first.gid, it->second.size())) + { + QueueRecord& qr = out_qr.external_local[it->first]; + qr.size = it->second.size(); + qr.external = storage_->put(it->second); + + out_qr.queues.erase(it++); + continue; + } // else keep in memory + } else + { + diy::save(bb, it->first); + diy::save(bb, it->second); + + out_qr.queues.erase(it++); + continue; + } + ++it; + } + + // TODO: this mechanism could be adjusted for direct saving to disk + // (without intermediate binary buffer serialization) + out_qr.external = storage_->put(bb); + } +} + +void +diy::Master:: +load(int i) +{ + log->debug("Loading block: {}", gid(i)); + + 
blocks_.load(i); + load_queues(i); +} + +void +diy::Master:: +load_queues(int i) +{ + load_incoming(gid(i)); + load_outgoing(gid(i)); +} + +void +diy::Master:: +load_incoming(int gid) +{ + IncomingQueuesRecords& in_qrs = incoming_[exchange_round_].map[gid]; + for (InQueueRecords::iterator it = in_qrs.records.begin(); it != in_qrs.records.end(); ++it) + { + QueueRecord& qr = it->second; + if (qr.external != -1) + { + log->debug("Loading queue: {} <- {}", gid, it->first); + storage_->get(qr.external, in_qrs.queues[it->first]); + qr.external = -1; + } + } +} + +void +diy::Master:: +load_outgoing(int gid) +{ + // TODO: we could adjust this mechanism to read directly from storage, + // bypassing an intermediate MemoryBuffer + OutgoingQueuesRecord& out_qr = outgoing_[gid]; + if (out_qr.external != -1) + { + MemoryBuffer bb; + storage_->get(out_qr.external, bb); + out_qr.external = -1; + + size_t count; + diy::load(bb, count); + for (size_t i = 0; i < count; ++i) + { + BlockID to; + diy::load(bb, to); + diy::load(bb, out_qr.queues[to]); + } + } +} + +diy::Master::ProxyWithLink +diy::Master:: +proxy(int i) const +{ return ProxyWithLink(Proxy(const_cast(this), gid(i)), block(i), link(i)); } + + +int +diy::Master:: +add(int gid, void* b, Link* l) +{ + if (*blocks_.in_memory().const_access() == limit_) + unload_all(); + + lock_guard lock(add_mutex_); // allow to add blocks from multiple threads + + blocks_.add(b); + links_.push_back(l); + gids_.push_back(gid); + + int lid = gids_.size() - 1; + lids_[gid] = lid; + add_expected(l->size_unique()); // NB: at every iteration we expect a message from each unique neighbor + + return lid; +} + +void* +diy::Master:: +release(int i) +{ + void* b = blocks_.release(i); + delete link(i); links_[i] = 0; + lids_.erase(gid(i)); + return b; +} + +bool +diy::Master:: +has_incoming(int i) const +{ + const IncomingQueuesRecords& in_qrs = const_cast(*this).incoming_[exchange_round_].map[gid(i)]; + for (InQueueRecords::const_iterator it = 
in_qrs.records.begin(); it != in_qrs.records.end(); ++it) + { + const QueueRecord& qr = it->second; + if (qr.size != 0) + return true; + } + return false; +} + +template +void +diy::Master:: +foreach_(const Callback& f, const Skip& skip) +{ + auto scoped = prof.scoped("foreach"); + commands_.push_back(new Command(f, skip)); + + if (immediate()) + execute(); +} + +void +diy::Master:: +execute() +{ + log->debug("Entered execute()"); + auto scoped = prof.scoped("execute"); + //show_incoming_records(); + + // touch the outgoing and incoming queues as well as collectives to make sure they exist + for (unsigned i = 0; i < size(); ++i) + { + outgoing(gid(i)); + incoming(gid(i)); // implicitly touches queue records + collectives(gid(i)); + } + + if (commands_.empty()) + return; + + // Order the blocks, so the loaded ones come first + std::deque blocks; + for (unsigned i = 0; i < size(); ++i) + if (block(i) == 0) + blocks.push_back(i); + else + blocks.push_front(i); + + // don't use more threads than we can have blocks in memory + int num_threads; + int blocks_per_thread; + if (limit_ == -1) + { + num_threads = threads_; + blocks_per_thread = size(); + } + else + { + num_threads = std::min(threads_, limit_); + blocks_per_thread = limit_/num_threads; + } + + // idx is shared + critical_resource idx(0); + + typedef ProcessBlock BlockFunctor; + if (num_threads > 1) + { + // launch the threads + typedef std::pair ThreadFunctorPair; + typedef std::list ThreadFunctorList; + ThreadFunctorList threads; + for (unsigned i = 0; i < (unsigned)num_threads; ++i) + { + BlockFunctor* bf = new BlockFunctor(*this, blocks, blocks_per_thread, idx); + threads.push_back(ThreadFunctorPair(new thread(&BlockFunctor::run, bf), bf)); + } + + // join the threads + for(ThreadFunctorList::iterator it = threads.begin(); it != threads.end(); ++it) + { + thread* t = it->first; + BlockFunctor* bf = it->second; + t->join(); + delete t; + delete bf; + } + } else + { + BlockFunctor bf(*this, blocks, 
blocks_per_thread, idx); + BlockFunctor::run(&bf); + } + + // clear incoming queues + incoming_[exchange_round_].map.clear(); + + if (limit() != -1 && in_memory() > limit()) + throw std::runtime_error(fmt::format("Fatal: {} blocks in memory, with limit {}", in_memory(), limit())); + + // clear commands + for (size_t i = 0; i < commands_.size(); ++i) + delete commands_[i]; + commands_.clear(); +} + +void +diy::Master:: +exchange() +{ + auto scoped = prof.scoped("exchange"); + execute(); + + log->debug("Starting exchange"); + + // make sure there is a queue for each neighbor + for (int i = 0; i < (int)size(); ++i) + { + OutgoingQueues& outgoing_queues = outgoing_[gid(i)].queues; + OutQueueRecords& external_local = outgoing_[gid(i)].external_local; + if (outgoing_queues.size() < (size_t)link(i)->size()) + for (unsigned j = 0; j < (unsigned)link(i)->size(); ++j) + { + if (external_local.find(link(i)->target(j)) == external_local.end()) + outgoing_queues[link(i)->target(j)]; // touch the outgoing queue, creating it if necessary + } + } + + flush(); + log->debug("Finished exchange"); +} + +namespace diy +{ +namespace detail +{ + template + struct VectorWindow + { + T *begin; + size_t count; + }; +} // namespace detail + +namespace mpi +{ +namespace detail +{ + template struct is_mpi_datatype< diy::detail::VectorWindow > { typedef true_type type; }; + + template + struct mpi_datatype< diy::detail::VectorWindow > + { + typedef diy::detail::VectorWindow VecWin; + static MPI_Datatype datatype() { return get_mpi_datatype(); } + static const void* address(const VecWin& x) { return x.begin; } + static void* address(VecWin& x) { return x.begin; } + static int count(const VecWin& x) { return static_cast(x.count); } + }; +} +} // namespace mpi::detail + +} // namespace diy + +/* Communicator */ +void +diy::Master:: +comm_exchange(ToSendList& to_send, int out_queues_limit) +{ + static const size_t MAX_MPI_MESSAGE_COUNT = INT_MAX; + + IncomingRound ¤t_incoming = 
incoming_[exchange_round_]; + // isend outgoing queues, up to the out_queues_limit + while(inflight_sends_.size() < (size_t)out_queues_limit && !to_send.empty()) + { + int from = to_send.front(); + + // deal with external_local queues + for (OutQueueRecords::iterator it = outgoing_[from].external_local.begin(); it != outgoing_[from].external_local.end(); ++it) + { + int to = it->first.gid; + + log->debug("Processing local queue: {} <- {} of size {}", to, from, it->second.size); + + QueueRecord& in_qr = current_incoming.map[to].records[from]; + bool in_external = block(lid(to)) == 0; + + if (in_external) + in_qr = it->second; + else + { + // load the queue + in_qr.size = it->second.size; + in_qr.external = -1; + + MemoryBuffer bb; + storage_->get(it->second.external, bb); + + current_incoming.map[to].queues[from].swap(bb); + } + ++current_incoming.received; + } + outgoing_[from].external_local.clear(); + + if (outgoing_[from].external != -1) + load_outgoing(from); + to_send.pop_front(); + + OutgoingQueues& outgoing = outgoing_[from].queues; + for (OutgoingQueues::iterator it = outgoing.begin(); it != outgoing.end(); ++it) + { + BlockID to_proc = it->first; + int to = to_proc.gid; + int proc = to_proc.proc; + + log->debug("Processing queue: {} <- {} of size {}", to, from, outgoing_[from].queues[to_proc].size()); + + // There may be local outgoing queues that remained in memory + if (proc == comm_.rank()) // sending to ourselves: simply swap buffers + { + log->debug("Moving queue in-place: {} <- {}", to, from); + + QueueRecord& in_qr = current_incoming.map[to].records[from]; + bool in_external = block(lid(to)) == 0; + if (in_external) + { + log->debug("Unloading outgoing directly as incoming: {} <- {}", to, from); + MemoryBuffer& bb = it->second; + in_qr.size = bb.size(); + if (queue_policy_->unload_incoming(*this, from, to, in_qr.size)) + in_qr.external = storage_->put(bb); + else + { + MemoryBuffer& in_bb = current_incoming.map[to].queues[from]; + in_bb.swap(bb); + 
in_bb.reset(); + in_qr.external = -1; + } + } else // !in_external + { + log->debug("Swapping in memory: {} <- {}", to, from); + MemoryBuffer& bb = current_incoming.map[to].queues[from]; + bb.swap(it->second); + bb.reset(); + in_qr.size = bb.size(); + in_qr.external = -1; + } + + ++current_incoming.received; + continue; + } + + std::shared_ptr buffer = std::make_shared(); + buffer->swap(it->second); + + MessageInfo info{from, to, exchange_round_}; + if (buffer->size() <= (MAX_MPI_MESSAGE_COUNT - sizeof(info))) + { + diy::save(*buffer, info); + + inflight_sends_.emplace_back(); + inflight_sends_.back().info = info; + inflight_sends_.back().request = comm_.isend(proc, tags::queue, buffer->buffer); + inflight_sends_.back().message = buffer; + } + else + { + int npieces = static_cast((buffer->size() + MAX_MPI_MESSAGE_COUNT - 1)/MAX_MPI_MESSAGE_COUNT); + + // first send the head + std::shared_ptr hb = std::make_shared(); + diy::save(*hb, buffer->size()); + diy::save(*hb, info); + + inflight_sends_.emplace_back(); + inflight_sends_.back().info = info; + inflight_sends_.back().request = comm_.isend(proc, tags::piece, hb->buffer); + inflight_sends_.back().message = hb; + + // send the message pieces + size_t msg_buff_idx = 0; + for (int i = 0; i < npieces; ++i, msg_buff_idx += MAX_MPI_MESSAGE_COUNT) + { + int tag = (i == (npieces - 1)) ? 
tags::queue : tags::piece; + + detail::VectorWindow window; + window.begin = &buffer->buffer[msg_buff_idx]; + window.count = std::min(MAX_MPI_MESSAGE_COUNT, buffer->size() - msg_buff_idx); + + inflight_sends_.emplace_back(); + inflight_sends_.back().info = info; + inflight_sends_.back().request = comm_.isend(proc, tag, window); + inflight_sends_.back().message = buffer; + } + } + } + } + + // kick requests + while(nudge()); + + // check incoming queues + mpi::optional ostatus = comm_.iprobe(mpi::any_source, mpi::any_tag); + while(ostatus) + { + InFlightRecv &ir = inflight_recvs_[ostatus->source()]; + + if (ir.info.from == -1) // uninitialized + { + MemoryBuffer bb; + comm_.recv(ostatus->source(), ostatus->tag(), bb.buffer); + + if (ostatus->tag() == tags::piece) + { + size_t msg_size; + diy::load(bb, msg_size); + diy::load(bb, ir.info); + + ir.message.buffer.reserve(msg_size); + } + else // tags::queue + { + diy::load_back(bb, ir.info); + ir.message.swap(bb); + } + } + else + { + size_t start_idx = ir.message.buffer.size(); + size_t count = ostatus->count(); + ir.message.buffer.resize(start_idx + count); + + detail::VectorWindow window; + window.begin = &ir.message.buffer[start_idx]; + window.count = count; + + comm_.recv(ostatus->source(), ostatus->tag(), window); + } + + if (ostatus->tag() == tags::queue) + { + size_t size = ir.message.size(); + int from = ir.info.from; + int to = ir.info.to; + int external = -1; + + assert(ir.info.round >= exchange_round_); + IncomingRound *in = &incoming_[ir.info.round]; + + bool unload_queue = ((ir.info.round == exchange_round_) ? 
(block(lid(to)) == 0) : (limit_ != -1)) && + queue_policy_->unload_incoming(*this, from, to, size); + if (unload_queue) + { + log->debug("Directly unloading queue {} <- {}", to, from); + external = storage_->put(ir.message); // unload directly + } + else + { + in->map[to].queues[from].swap(ir.message); + in->map[to].queues[from].reset(); // buffer position = 0 + } + in->map[to].records[from] = QueueRecord(size, external); + + ++(in->received); + ir = InFlightRecv(); // reset + } + + ostatus = comm_.iprobe(mpi::any_source, mpi::any_tag); + } +} + +void +diy::Master:: +flush() +{ + + auto scoped = prof.scoped("comm"); +#ifdef DEBUG + time_type start = get_time(); + unsigned wait = 1; +#endif + + // prepare for next round + incoming_.erase(exchange_round_); + ++exchange_round_; + + // make a list of outgoing queues to send (the ones in memory come first) + ToSendList to_send; + for (OutgoingQueuesMap::iterator it = outgoing_.begin(); it != outgoing_.end(); ++it) + { + OutgoingQueuesRecord& out = it->second; + if (out.external == -1) + to_send.push_front(it->first); + else + to_send.push_back(it->first); + } + log->debug("to_send.size(): {}", to_send.size()); + + // XXX: we probably want a cleverer limit than block limit times average number of queues per block + // XXX: with queues we could easily maintain a specific space limit + int out_queues_limit; + if (limit_ == -1 || size() == 0) + out_queues_limit = to_send.size(); + else + out_queues_limit = std::max((size_t) 1, to_send.size()/size()*limit_); // average number of queues per block * in-memory block limit + + do + { + comm_exchange(to_send, out_queues_limit); + +#ifdef DEBUG + time_type cur = get_time(); + if (cur - start > wait*1000) + { + log->warn("Waiting in flush [{}]: {} - {} out of {}", + comm_.rank(), inflight_sends_.size(), incoming_[exchange_round_].received, expected_); + wait *= 2; + } +#endif + } while (!inflight_sends_.empty() || incoming_[exchange_round_].received < expected_ || 
!to_send.empty()); + + outgoing_.clear(); + + log->debug("Done in flush"); + //show_incoming_records(); + + process_collectives(); +} + +void +diy::Master:: +process_collectives() +{ + auto scoped = prof.scoped("collectives"); + + if (collectives_.empty()) + return; + + typedef CollectivesList::iterator CollectivesIterator; + std::vector iters; + std::vector gids; + for (CollectivesMap::iterator cur = collectives_.begin(); cur != collectives_.end(); ++cur) + { + gids.push_back(cur->first); + iters.push_back(cur->second.begin()); + } + + while (iters[0] != collectives_.begin()->second.end()) + { + iters[0]->init(); + for (unsigned j = 1; j < iters.size(); ++j) + { + // NB: this assumes that the operations are commutative + iters[0]->update(*iters[j]); + } + iters[0]->global(comm_); // do the mpi collective + + for (unsigned j = 1; j < iters.size(); ++j) + { + iters[j]->copy_from(*iters[0]); + ++iters[j]; + } + + ++iters[0]; + } +} + +bool +diy::Master:: +nudge() +{ + bool success = false; + for (InFlightSendsList::iterator it = inflight_sends_.begin(); it != inflight_sends_.end(); ++it) + { + mpi::optional ostatus = it->request.test(); + if (ostatus) + { + success = true; + InFlightSendsList::iterator rm = it; + --it; + inflight_sends_.erase(rm); + } + } + return success; +} + +void +diy::Master:: +show_incoming_records() const +{ + for (IncomingRoundMap::const_iterator rounds_itr = incoming_.begin(); rounds_itr != incoming_.end(); ++rounds_itr) + { + for (IncomingQueuesMap::const_iterator it = rounds_itr->second.map.begin(); it != rounds_itr->second.map.end(); ++it) + { + const IncomingQueuesRecords& in_qrs = it->second; + for (InQueueRecords::const_iterator cur = in_qrs.records.begin(); cur != in_qrs.records.end(); ++cur) + { + const QueueRecord& qr = cur->second; + log->info("round: {}, {} <- {}: (size,external) = ({},{})", + rounds_itr->first, + it->first, cur->first, + qr.size, + qr.external); + } + for (IncomingQueues::const_iterator cur = 
in_qrs.queues.begin(); cur != in_qrs.queues.end(); ++cur) + { + log->info("round: {}, {} <- {}: queue.size() = {}", + rounds_itr->first, + it->first, cur->first, + const_cast(in_qrs).queues[cur->first].size()); + } + } + } +} + +#endif diff --git a/diy/include/diy/mpi.hpp b/diy/include/diy/mpi.hpp new file mode 100644 index 000000000..28502002f --- /dev/null +++ b/diy/include/diy/mpi.hpp @@ -0,0 +1,32 @@ +#ifndef DIY_MPI_HPP +#define DIY_MPI_HPP + +#include + +#include "mpi/constants.hpp" +#include "mpi/datatypes.hpp" +#include "mpi/optional.hpp" +#include "mpi/status.hpp" +#include "mpi/request.hpp" +#include "mpi/point-to-point.hpp" +#include "mpi/communicator.hpp" +#include "mpi/collectives.hpp" +#include "mpi/io.hpp" + +namespace diy +{ +namespace mpi +{ + +//! \ingroup MPI +struct environment +{ + environment() { int argc = 0; char** argv; MPI_Init(&argc, &argv); } + environment(int argc, char* argv[]) { MPI_Init(&argc, &argv); } + ~environment() { MPI_Finalize(); } +}; + +} +} + +#endif diff --git a/diy/include/diy/mpi/collectives.hpp b/diy/include/diy/mpi/collectives.hpp new file mode 100644 index 000000000..4324534e5 --- /dev/null +++ b/diy/include/diy/mpi/collectives.hpp @@ -0,0 +1,328 @@ +#include + +#include "operations.hpp" + +namespace diy +{ +namespace mpi +{ + //!\addtogroup MPI + //!@{ + + template + struct Collectives + { + typedef detail::mpi_datatype Datatype; + + static void broadcast(const communicator& comm, T& x, int root) + { + MPI_Bcast(Datatype::address(x), + Datatype::count(x), + Datatype::datatype(), root, comm); + } + + static void broadcast(const communicator& comm, std::vector& x, int root) + { + size_t sz = x.size(); + Collectives::broadcast(comm, sz, root); + + if (comm.rank() != root) + x.resize(sz); + + MPI_Bcast(Datatype::address(x[0]), + x.size(), + Datatype::datatype(), root, comm); + } + + static request ibroadcast(const communicator& comm, T& x, int root) + { + request r; + MPI_Ibcast(Datatype::address(x), + 
Datatype::count(x), + Datatype::datatype(), root, comm, &r.r); + return r; + } + + static void gather(const communicator& comm, const T& in, std::vector& out, int root) + { + size_t s = comm.size(); + s *= Datatype::count(in); + out.resize(s); + MPI_Gather(Datatype::address(const_cast(in)), + Datatype::count(in), + Datatype::datatype(), + Datatype::address(out[0]), + Datatype::count(in), + Datatype::datatype(), + root, comm); + } + + static void gather(const communicator& comm, const std::vector& in, std::vector< std::vector >& out, int root) + { + std::vector counts(comm.size()); + Collectives::gather(comm, (int) in.size(), counts, root); + + std::vector offsets(comm.size(), 0); + for (unsigned i = 1; i < offsets.size(); ++i) + offsets[i] = offsets[i-1] + counts[i-1]; + + std::vector buffer(offsets.back() + counts.back()); + MPI_Gatherv(Datatype::address(const_cast(in[0])), + in.size(), + Datatype::datatype(), + Datatype::address(buffer[0]), + &counts[0], + &offsets[0], + Datatype::datatype(), + root, comm); + + out.resize(comm.size()); + size_t cur = 0; + for (unsigned i = 0; i < (unsigned)comm.size(); ++i) + { + out[i].reserve(counts[i]); + for (unsigned j = 0; j < (unsigned)counts[i]; ++j) + out[i].push_back(buffer[cur++]); + } + } + + static void gather(const communicator& comm, const T& in, int root) + { + MPI_Gather(Datatype::address(const_cast(in)), + Datatype::count(in), + Datatype::datatype(), + Datatype::address(const_cast(in)), + Datatype::count(in), + Datatype::datatype(), + root, comm); + } + + static void gather(const communicator& comm, const std::vector& in, int root) + { + Collectives::gather(comm, (int) in.size(), root); + + MPI_Gatherv(Datatype::address(const_cast(in[0])), + in.size(), + Datatype::datatype(), + 0, 0, 0, + Datatype::datatype(), + root, comm); + } + + static void all_gather(const communicator& comm, const T& in, std::vector& out) + { + size_t s = comm.size(); + s *= Datatype::count(in); + out.resize(s); + 
MPI_Allgather(Datatype::address(const_cast(in)), + Datatype::count(in), + Datatype::datatype(), + Datatype::address(out[0]), + Datatype::count(in), + Datatype::datatype(), + comm); + } + + static void all_gather(const communicator& comm, const std::vector& in, std::vector< std::vector >& out) + { + std::vector counts(comm.size()); + Collectives::all_gather(comm, (int) in.size(), counts); + + std::vector offsets(comm.size(), 0); + for (unsigned i = 1; i < offsets.size(); ++i) + offsets[i] = offsets[i-1] + counts[i-1]; + + std::vector buffer(offsets.back() + counts.back()); + MPI_Allgatherv(Datatype::address(const_cast(in[0])), + in.size(), + Datatype::datatype(), + Datatype::address(buffer[0]), + &counts[0], + &offsets[0], + Datatype::datatype(), + comm); + + out.resize(comm.size()); + size_t cur = 0; + for (int i = 0; i < comm.size(); ++i) + { + out[i].reserve(counts[i]); + for (int j = 0; j < counts[i]; ++j) + out[i].push_back(buffer[cur++]); + } + } + + static void reduce(const communicator& comm, const T& in, T& out, int root, const Op& op) + { + MPI_Reduce(Datatype::address(const_cast(in)), + Datatype::address(out), + Datatype::count(in), + Datatype::datatype(), + detail::mpi_op::get(op), + root, comm); + } + + static void reduce(const communicator& comm, const T& in, int root, const Op& op) + { + MPI_Reduce(Datatype::address(const_cast(in)), + Datatype::address(const_cast(in)), + Datatype::count(in), + Datatype::datatype(), + detail::mpi_op::get(op), + root, comm); + } + + static void all_reduce(const communicator& comm, const T& in, T& out, const Op& op) + { + MPI_Allreduce(Datatype::address(const_cast(in)), + Datatype::address(out), + Datatype::count(in), + Datatype::datatype(), + detail::mpi_op::get(op), + comm); + } + + static void all_reduce(const communicator& comm, const std::vector& in, std::vector& out, const Op& op) + { + out.resize(in.size()); + MPI_Allreduce(Datatype::address(const_cast(in[0])), + Datatype::address(out[0]), + in.size(), + 
Datatype::datatype(), + detail::mpi_op::get(op), + comm); + } + + static void scan(const communicator& comm, const T& in, T& out, const Op& op) + { + MPI_Scan(Datatype::address(const_cast(in)), + Datatype::address(out), + Datatype::count(in), + Datatype::datatype(), + detail::mpi_op::get(op), + comm); + } + + static void all_to_all(const communicator& comm, const std::vector& in, std::vector& out, int n = 1) + { + // NB: this will fail if T is a vector + MPI_Alltoall(Datatype::address(const_cast(in[0])), n, + Datatype::datatype(), + Datatype::address(out[0]), n, + Datatype::datatype(), + comm); + } + }; + + //! Broadcast to all processes in `comm`. + template + void broadcast(const communicator& comm, T& x, int root) + { + Collectives::broadcast(comm, x, root); + } + + //! Broadcast for vectors + template + void broadcast(const communicator& comm, std::vector& x, int root) + { + Collectives::broadcast(comm, x, root); + } + + //! iBroadcast to all processes in `comm`. + template + request ibroadcast(const communicator& comm, T& x, int root) + { + return Collectives::ibroadcast(comm, x, root); + } + + //! Gather from all processes in `comm`. + //! On `root` process, `out` is resized to `comm.size()` and filled with + //! elements from the respective ranks. + template + void gather(const communicator& comm, const T& in, std::vector& out, int root) + { + Collectives::gather(comm, in, out, root); + } + + //! Same as above, but for vectors. + template + void gather(const communicator& comm, const std::vector& in, std::vector< std::vector >& out, int root) + { + Collectives::gather(comm, in, out, root); + } + + //! Simplified version (without `out`) for use on non-root processes. + template + void gather(const communicator& comm, const T& in, int root) + { + Collectives::gather(comm, in, root); + } + + //! Simplified version (without `out`) for use on non-root processes. 
+ template + void gather(const communicator& comm, const std::vector& in, int root) + { + Collectives::gather(comm, in, root); + } + + //! all_gather from all processes in `comm`. + //! `out` is resized to `comm.size()` and filled with + //! elements from the respective ranks. + template + void all_gather(const communicator& comm, const T& in, std::vector& out) + { + Collectives::all_gather(comm, in, out); + } + + //! Same as above, but for vectors. + template + void all_gather(const communicator& comm, const std::vector& in, std::vector< std::vector >& out) + { + Collectives::all_gather(comm, in, out); + } + + //! reduce + template + void reduce(const communicator& comm, const T& in, T& out, int root, const Op& op) + { + Collectives::reduce(comm, in, out, root, op); + } + + //! Simplified version (without `out`) for use on non-root processes. + template + void reduce(const communicator& comm, const T& in, int root, const Op& op) + { + Collectives::reduce(comm, in, root, op); + } + + //! all_reduce + template + void all_reduce(const communicator& comm, const T& in, T& out, const Op& op) + { + Collectives::all_reduce(comm, in, out, op); + } + + //! Same as above, but for vectors. + template + void all_reduce(const communicator& comm, const std::vector& in, std::vector& out, const Op& op) + { + Collectives::all_reduce(comm, in, out, op); + } + + //! scan + template + void scan(const communicator& comm, const T& in, T& out, const Op& op) + { + Collectives::scan(comm, in, out, op); + } + + //! all_to_all + template + void all_to_all(const communicator& comm, const std::vector& in, std::vector& out, int n = 1) + { + Collectives::all_to_all(comm, in, out, n); + } + + //!@} +} +} diff --git a/diy/include/diy/mpi/communicator.hpp b/diy/include/diy/mpi/communicator.hpp new file mode 100644 index 000000000..d1bdf33f7 --- /dev/null +++ b/diy/include/diy/mpi/communicator.hpp @@ -0,0 +1,72 @@ +namespace diy +{ +namespace mpi +{ + + //! \ingroup MPI + //! 
Simple wrapper around `MPI_Comm`. + class communicator + { + public: + communicator(MPI_Comm comm = MPI_COMM_WORLD): + comm_(comm), rank_(0), size_(1) { if (comm != MPI_COMM_NULL) { MPI_Comm_rank(comm_, &rank_); MPI_Comm_size(comm_, &size_); } } + + int rank() const { return rank_; } + int size() const { return size_; } + + //void send(int dest, + // int tag, + // const void* buf, + // MPI_Datatype datatype) const { } + + //! Send `x` to processor `dest` using `tag` (blocking). + template + void send(int dest, int tag, const T& x) const { detail::send()(comm_, dest, tag, x); } + + //! Receive `x` from `dest` using `tag` (blocking). + //! If `T` is an `std::vector<...>`, `recv` will resize it to fit exactly the sent number of values. + template + status recv(int source, int tag, T& x) const { return detail::recv()(comm_, source, tag, x); } + + //! Non-blocking version of `send()`. + template + request isend(int dest, int tag, const T& x) const { return detail::isend()(comm_, dest, tag, x); } + + //! Non-blocking version of `recv()`. + //! If `T` is an `std::vector<...>`, its size must be big enough to accomodate the sent values. + template + request irecv(int source, int tag, T& x) const { return detail::irecv()(comm_, source, tag, x); } + + //! probe + status probe(int source, int tag) const { status s; MPI_Probe(source, tag, comm_, &s.s); return s; } + + //! iprobe + inline + optional + iprobe(int source, int tag) const; + + //! 
barrier + void barrier() const { MPI_Barrier(comm_); } + + operator MPI_Comm() const { return comm_; } + + private: + MPI_Comm comm_; + int rank_; + int size_; + }; +} +} + +diy::mpi::optional +diy::mpi::communicator:: +iprobe(int source, int tag) const +{ + status s; + int flag; + MPI_Iprobe(source, tag, comm_, &flag, &s.s); + if (flag) + return s; + return optional(); +} + diff --git a/diy/include/diy/mpi/constants.hpp b/diy/include/diy/mpi/constants.hpp new file mode 100644 index 000000000..7668e418f --- /dev/null +++ b/diy/include/diy/mpi/constants.hpp @@ -0,0 +1,13 @@ +#ifndef DIY_MPI_CONSTANTS_HPP +#define DIY_MPI_CONSTANTS_HPP + +namespace diy +{ +namespace mpi +{ + const int any_source = MPI_ANY_SOURCE; + const int any_tag = MPI_ANY_TAG; +} +} + +#endif diff --git a/diy/include/diy/mpi/datatypes.hpp b/diy/include/diy/mpi/datatypes.hpp new file mode 100644 index 000000000..7d8e3a448 --- /dev/null +++ b/diy/include/diy/mpi/datatypes.hpp @@ -0,0 +1,63 @@ +#ifndef DIY_MPI_DATATYPES_HPP +#define DIY_MPI_DATATYPES_HPP + +#include + +namespace diy +{ +namespace mpi +{ +namespace detail +{ + template MPI_Datatype get_mpi_datatype(); + + struct true_type {}; + struct false_type {}; + + /* is_mpi_datatype */ + template + struct is_mpi_datatype { typedef false_type type; }; + +#define DIY_MPI_DATATYPE_MAP(cpp_type, mpi_type) \ + template<> inline MPI_Datatype get_mpi_datatype() { return mpi_type; } \ + template<> struct is_mpi_datatype { typedef true_type type; }; \ + template<> struct is_mpi_datatype< std::vector > { typedef true_type type; }; + + DIY_MPI_DATATYPE_MAP(char, MPI_BYTE); + DIY_MPI_DATATYPE_MAP(unsigned char, MPI_BYTE); + DIY_MPI_DATATYPE_MAP(bool, MPI_BYTE); + DIY_MPI_DATATYPE_MAP(int, MPI_INT); + DIY_MPI_DATATYPE_MAP(unsigned, MPI_UNSIGNED); + DIY_MPI_DATATYPE_MAP(long, MPI_LONG); + DIY_MPI_DATATYPE_MAP(unsigned long, MPI_UNSIGNED_LONG); + DIY_MPI_DATATYPE_MAP(long long, MPI_LONG_LONG_INT); + DIY_MPI_DATATYPE_MAP(unsigned long long, 
MPI_UNSIGNED_LONG_LONG); + DIY_MPI_DATATYPE_MAP(float, MPI_FLOAT); + DIY_MPI_DATATYPE_MAP(double, MPI_DOUBLE); + + /* mpi_datatype: helper routines, specialized for std::vector<...> */ + template + struct mpi_datatype + { + static MPI_Datatype datatype() { return get_mpi_datatype(); } + static const void* address(const T& x) { return &x; } + static void* address(T& x) { return &x; } + static int count(const T& x) { return 1; } + }; + + template + struct mpi_datatype< std::vector > + { + typedef std::vector VecU; + + static MPI_Datatype datatype() { return get_mpi_datatype(); } + static const void* address(const VecU& x) { return &x[0]; } + static void* address(VecU& x) { return &x[0]; } + static int count(const VecU& x) { return x.size(); } + }; + +} +} +} + +#endif diff --git a/diy/include/diy/mpi/io.hpp b/diy/include/diy/mpi/io.hpp new file mode 100644 index 000000000..ebe6a2e17 --- /dev/null +++ b/diy/include/diy/mpi/io.hpp @@ -0,0 +1,137 @@ +#ifndef DIY_MPI_IO_HPP +#define DIY_MPI_IO_HPP + +#include +#include + +namespace diy +{ +namespace mpi +{ +namespace io +{ + typedef MPI_Offset offset; + + //! Wraps MPI file IO. 
\ingroup MPI + class file + { + public: + enum + { + rdonly = MPI_MODE_RDONLY, + rdwr = MPI_MODE_RDWR, + wronly = MPI_MODE_WRONLY, + create = MPI_MODE_CREATE, + exclusive = MPI_MODE_EXCL, + delete_on_close = MPI_MODE_DELETE_ON_CLOSE, + unique_open = MPI_MODE_UNIQUE_OPEN, + sequential = MPI_MODE_SEQUENTIAL, + append = MPI_MODE_APPEND + }; + + public: + file(const communicator& comm, + const std::string& filename, + int mode): + comm_(comm) { MPI_File_open(comm, const_cast(filename.c_str()), mode, MPI_INFO_NULL, &fh); } + ~file() { close(); } + void close() { if (fh != MPI_FILE_NULL) MPI_File_close(&fh); } + + offset size() const { offset sz; MPI_File_get_size(fh, &sz); return sz; } + void resize(offset size) { MPI_File_set_size(fh, size); } + + inline void read_at(offset o, char* buffer, size_t size); + inline void read_at_all(offset o, char* buffer, size_t size); + inline void write_at(offset o, const char* buffer, size_t size); + inline void write_at_all(offset o, const char* buffer, size_t size); + + template + inline void read_at(offset o, std::vector& data); + + template + inline void read_at_all(offset o, std::vector& data); + + template + inline void write_at(offset o, const std::vector& data); + + template + inline void write_at_all(offset o, const std::vector& data); + + const communicator& + comm() const { return comm_; } + + MPI_File& handle() { return fh; } + + private: + const communicator& comm_; + MPI_File fh; + }; +} +} +} + +void +diy::mpi::io::file:: +read_at(offset o, char* buffer, size_t size) +{ + status s; + MPI_File_read_at(fh, o, buffer, size, detail::get_mpi_datatype(), &s.s); +} + +template +void +diy::mpi::io::file:: +read_at(offset o, std::vector& data) +{ + read_at(o, &data[0], data.size()*sizeof(T)); +} + +void +diy::mpi::io::file:: +read_at_all(offset o, char* buffer, size_t size) +{ + status s; + MPI_File_read_at_all(fh, o, buffer, size, detail::get_mpi_datatype(), &s.s); +} + +template +void +diy::mpi::io::file:: +read_at_all(offset 
o, std::vector& data) +{ + read_at_all(o, (char*) &data[0], data.size()*sizeof(T)); +} + +void +diy::mpi::io::file:: +write_at(offset o, const char* buffer, size_t size) +{ + status s; + MPI_File_write_at(fh, o, (void *)buffer, size, detail::get_mpi_datatype(), &s.s); +} + +template +void +diy::mpi::io::file:: +write_at(offset o, const std::vector& data) +{ + write_at(o, (const char*) &data[0], data.size()*sizeof(T)); +} + +void +diy::mpi::io::file:: +write_at_all(offset o, const char* buffer, size_t size) +{ + status s; + MPI_File_write_at_all(fh, o, (void *)buffer, size, detail::get_mpi_datatype(), &s.s); +} + +template +void +diy::mpi::io::file:: +write_at_all(offset o, const std::vector& data) +{ + write_at_all(o, &data[0], data.size()*sizeof(T)); +} + +#endif diff --git a/diy/include/diy/mpi/operations.hpp b/diy/include/diy/mpi/operations.hpp new file mode 100644 index 000000000..9c38e58ae --- /dev/null +++ b/diy/include/diy/mpi/operations.hpp @@ -0,0 +1,26 @@ +#include + +namespace diy +{ +namespace mpi +{ + //! 
\addtogroup MPI + //!@{ + template + struct maximum { const U& operator()(const U& x, const U& y) const { return std::max(x,y); } }; + template + struct minimum { const U& operator()(const U& x, const U& y) const { return std::min(x,y); } }; + //!@} + +namespace detail +{ + template struct mpi_op { static MPI_Op get(const T&); }; + template struct mpi_op< maximum > { static MPI_Op get(const maximum&) { return MPI_MAX; } }; + template struct mpi_op< minimum > { static MPI_Op get(const minimum&) { return MPI_MIN; } }; + template struct mpi_op< std::plus > { static MPI_Op get(const std::plus&) { return MPI_SUM; } }; + template struct mpi_op< std::multiplies > { static MPI_Op get(const std::multiplies&) { return MPI_PROD; } }; + template struct mpi_op< std::logical_and > { static MPI_Op get(const std::logical_and&) { return MPI_LAND; } }; + template struct mpi_op< std::logical_or > { static MPI_Op get(const std::logical_or&) { return MPI_LOR; } }; +} +} +} diff --git a/diy/include/diy/mpi/optional.hpp b/diy/include/diy/mpi/optional.hpp new file mode 100644 index 000000000..ab58aaf81 --- /dev/null +++ b/diy/include/diy/mpi/optional.hpp @@ -0,0 +1,55 @@ +namespace diy +{ +namespace mpi +{ + template + struct optional + { + optional(): + init_(false) {} + + optional(const T& v): + init_(true) { new(buf_) T(v); } + + optional(const optional& o): + init_(o.init_) { if (init_) new(buf_) T(*o); } + + ~optional() { if (init_) clear(); } + + inline + optional& operator=(const optional& o); + + operator bool() const { return init_; } + + T& operator*() { return *static_cast(address()); } + const T& operator*() const { return *static_cast(address()); } + + T* operator->() { return &(operator*()); } + const T* operator->() const { return &(operator*()); } + + private: + void clear() { static_cast(address())->~T(); } + + void* address() { return buf_; } + const void* address() const { return buf_; } + + private: + bool init_; + char buf_[sizeof(T)]; + }; +} +} + +template 
+diy::mpi::optional& +diy::mpi::optional:: +operator=(const optional& o) +{ + if (init_) + clear(); + init_ = o.init_; + if (init_) + new (buf_) T(*o); + + return *this; +} diff --git a/diy/include/diy/mpi/point-to-point.hpp b/diy/include/diy/mpi/point-to-point.hpp new file mode 100644 index 000000000..dc8a341dc --- /dev/null +++ b/diy/include/diy/mpi/point-to-point.hpp @@ -0,0 +1,98 @@ +#include + +namespace diy +{ +namespace mpi +{ +namespace detail +{ + // send + template< class T, class is_mpi_datatype_ = typename is_mpi_datatype::type > + struct send; + + template + struct send + { + void operator()(MPI_Comm comm, int dest, int tag, const T& x) const + { + typedef mpi_datatype Datatype; + MPI_Send((void*) Datatype::address(x), + Datatype::count(x), + Datatype::datatype(), + dest, tag, comm); + } + }; + + // recv + template< class T, class is_mpi_datatype_ = typename is_mpi_datatype::type > + struct recv; + + template + struct recv + { + status operator()(MPI_Comm comm, int source, int tag, T& x) const + { + typedef mpi_datatype Datatype; + status s; + MPI_Recv((void*) Datatype::address(x), + Datatype::count(x), + Datatype::datatype(), + source, tag, comm, &s.s); + return s; + } + }; + + template + struct recv, true_type> + { + status operator()(MPI_Comm comm, int source, int tag, std::vector& x) const + { + status s; + + MPI_Probe(source, tag, comm, &s.s); + x.resize(s.count()); + MPI_Recv(&x[0], x.size(), get_mpi_datatype(), source, tag, comm, &s.s); + return s; + } + }; + + // isend + template< class T, class is_mpi_datatype_ = typename is_mpi_datatype::type > + struct isend; + + template + struct isend + { + request operator()(MPI_Comm comm, int dest, int tag, const T& x) const + { + request r; + typedef mpi_datatype Datatype; + MPI_Isend((void*) Datatype::address(x), + Datatype::count(x), + Datatype::datatype(), + dest, tag, comm, &r.r); + return r; + } + }; + + // irecv + template< class T, class is_mpi_datatype_ = typename is_mpi_datatype::type > + 
struct irecv; + + template + struct irecv + { + request operator()(MPI_Comm comm, int source, int tag, T& x) const + { + request r; + typedef mpi_datatype Datatype; + MPI_Irecv(Datatype::address(x), + Datatype::count(x), + Datatype::datatype(), + source, tag, comm, &r.r); + return r; + } + }; +} +} +} diff --git a/diy/include/diy/mpi/request.hpp b/diy/include/diy/mpi/request.hpp new file mode 100644 index 000000000..23b11816e --- /dev/null +++ b/diy/include/diy/mpi/request.hpp @@ -0,0 +1,26 @@ +namespace diy +{ +namespace mpi +{ + struct request + { + status wait() { status s; MPI_Wait(&r, &s.s); return s; } + inline + optional test(); + void cancel() { MPI_Cancel(&r); } + + MPI_Request r; + }; +} +} + +diy::mpi::optional +diy::mpi::request::test() +{ + status s; + int flag; + MPI_Test(&r, &flag, &s.s); + if (flag) + return s; + return optional(); +} diff --git a/diy/include/diy/mpi/status.hpp b/diy/include/diy/mpi/status.hpp new file mode 100644 index 000000000..aab500c31 --- /dev/null +++ b/diy/include/diy/mpi/status.hpp @@ -0,0 +1,30 @@ +namespace diy +{ +namespace mpi +{ + struct status + { + int source() const { return s.MPI_SOURCE; } + int tag() const { return s.MPI_TAG; } + int error() const { return s.MPI_ERROR; } + bool cancelled() const { int flag; MPI_Test_cancelled(const_cast(&s), &flag); return flag; } + + template + int count() const; + + operator MPI_Status&() { return s; } + operator const MPI_Status&() const { return s; } + + MPI_Status s; + }; +} +} + +template +int +diy::mpi::status::count() const +{ + int c; + MPI_Get_count(const_cast(&s), detail::get_mpi_datatype(), &c); + return c; +} diff --git a/diy/include/diy/no-thread.hpp b/diy/include/diy/no-thread.hpp new file mode 100644 index 000000000..fd7af88ae --- /dev/null +++ b/diy/include/diy/no-thread.hpp @@ -0,0 +1,38 @@ +#ifndef DIY_NO_THREAD_HPP +#define DIY_NO_THREAD_HPP + +// replicates only the parts of the threading interface that we use +// executes everything in a single thread + 
+namespace diy +{ + struct thread + { + thread(void (*f)(void *), void* args): + f_(f), args_(args) {} + + void join() { f_(args_); } + + static unsigned hardware_concurrency() { return 1; } + + void (*f_)(void*); + void* args_; + }; + + struct mutex {}; + struct fast_mutex {}; + struct recursive_mutex {}; + + template + struct lock_guard + { + lock_guard(T&) {} + }; + + namespace this_thread + { + inline unsigned long int get_id() { return 0; } + } +} + +#endif diff --git a/diy/include/diy/partners/all-reduce.hpp b/diy/include/diy/partners/all-reduce.hpp new file mode 100644 index 000000000..e34066595 --- /dev/null +++ b/diy/include/diy/partners/all-reduce.hpp @@ -0,0 +1,72 @@ +#ifndef DIY_PARTNERS_ALL_REDUCE_HPP +#define DIY_PARTNERS_ALL_REDUCE_HPP + +#include "merge.hpp" + +namespace diy +{ + +class Master; + +//! Allreduce (reduction with results broadcasted to all blocks) is +//! implemented as two merge reductions, with incoming and outgoing items swapped in second one. +//! Ie, follows merge reduction up and down the merge tree + +/** + * \ingroup Communication + * \brief Partners for all-reduce + * + */ +struct RegularAllReducePartners: public RegularMergePartners +{ + typedef RegularMergePartners Parent; //!< base class merge reduction + + //! contiguous parameter indicates whether to match partners contiguously or in a round-robin fashion; + //! contiguous is useful when data needs to be united; + //! round-robin is useful for vector-"halving" + template + RegularAllReducePartners(const Decomposer& decomposer, //!< domain decomposition + int k, //!< target k value + bool contiguous = true //!< distance doubling (true) or halving (false) + ): + Parent(decomposer, k, contiguous) {} + RegularAllReducePartners(const DivisionVector& divs,//!< explicit division vector + const KVSVector& kvs, //!< explicit k vector + bool contiguous = true //!< distance doubling (true) or halving (false) + ): + Parent(divs, kvs, contiguous) {} + + //! 
returns total number of rounds + size_t rounds() const { return 2*Parent::rounds(); } + //! returns size of a group of partners in a given round + int size(int round) const { return Parent::size(parent_round(round)); } + //! returns dimension (direction of partners in a regular grid) in a given round + int dim(int round) const { return Parent::dim(parent_round(round)); } + //! returns whether a given block in a given round has dropped out of the merge yet or not + inline bool active(int round, int gid, const Master& m) const { return Parent::active(parent_round(round), gid, m); } + //! returns what the current round would be in the first or second parent merge reduction + int parent_round(int round) const { return round < (int) Parent::rounds() ? round : rounds() - round; } + + // incoming is only valid for an active gid; it will only be called with an active gid + inline void incoming(int round, int gid, std::vector& partners, const Master& m) const + { + if (round <= (int) Parent::rounds()) + Parent::incoming(round, gid, partners, m); + else + Parent::outgoing(parent_round(round), gid, partners, m); + } + + inline void outgoing(int round, int gid, std::vector& partners, const Master& m) const + { + if (round < (int) Parent::rounds()) + Parent::outgoing(round, gid, partners, m); + else + Parent::incoming(parent_round(round), gid, partners, m); + } +}; + +} // diy + +#endif + + diff --git a/diy/include/diy/partners/broadcast.hpp b/diy/include/diy/partners/broadcast.hpp new file mode 100644 index 000000000..d3f565f82 --- /dev/null +++ b/diy/include/diy/partners/broadcast.hpp @@ -0,0 +1,62 @@ +#ifndef DIY_PARTNERS_BROADCAST_HPP +#define DIY_PARTNERS_BROADCAST_HPP + +#include "merge.hpp" + +namespace diy +{ + +class Master; + +/** + * \ingroup Communication + * \brief Partners for broadcast + * + */ +struct RegularBroadcastPartners: public RegularMergePartners +{ + typedef RegularMergePartners Parent; //!< base class merge reduction + + //! 
contiguous parameter indicates whether to match partners contiguously or in a round-robin fashion; + //! contiguous is useful when data needs to be united; + //! round-robin is useful for vector-"halving" + template + RegularBroadcastPartners(const Decomposer& decomposer, //!< domain decomposition + int k, //!< target k value + bool contiguous = true //!< distance doubling (true) or halving (false) + ): + Parent(decomposer, k, contiguous) {} + RegularBroadcastPartners(const DivisionVector& divs,//!< explicit division vector + const KVSVector& kvs, //!< explicit k vector + bool contiguous = true //!< distance doubling (true) or halving (false) + ): + Parent(divs, kvs, contiguous) {} + + //! returns total number of rounds + size_t rounds() const { return Parent::rounds(); } + //! returns size of a group of partners in a given round + int size(int round) const { return Parent::size(parent_round(round)); } + //! returns dimension (direction of partners in a regular grid) in a given round + int dim(int round) const { return Parent::dim(parent_round(round)); } + //! returns whether a given block in a given round has dropped out of the merge yet or not + inline bool active(int round, int gid, const Master& m) const { return Parent::active(parent_round(round), gid, m); } + //! 
returns what the current round would be in the first or second parent merge reduction + int parent_round(int round) const { return rounds() - round; } + + // incoming is only valid for an active gid; it will only be called with an active gid + inline void incoming(int round, int gid, std::vector& partners, const Master& m) const + { + Parent::outgoing(parent_round(round), gid, partners, m); + } + + inline void outgoing(int round, int gid, std::vector& partners, const Master& m) const + { + Parent::incoming(parent_round(round), gid, partners, m); + } +}; + +} // diy + +#endif + + diff --git a/diy/include/diy/partners/common.hpp b/diy/include/diy/partners/common.hpp new file mode 100644 index 000000000..43f8297a0 --- /dev/null +++ b/diy/include/diy/partners/common.hpp @@ -0,0 +1,204 @@ +#ifndef DIY_PARTNERS_COMMON_HPP +#define DIY_PARTNERS_COMMON_HPP + +#include "../decomposition.hpp" +#include "../types.hpp" + +namespace diy +{ + +struct RegularPartners +{ + // The record of group size per round in a dimension + struct DimK + { + DimK(int dim_, int k_): + dim(dim_), size(k_) {} + + int dim; + int size; // group size + }; + + typedef std::vector CoordVector; + typedef std::vector DivisionVector; + typedef std::vector KVSVector; + + // The part of RegularDecomposer that we need works the same with either Bounds (so we fix them arbitrarily) + typedef DiscreteBounds Bounds; + typedef RegularDecomposer Decomposer; + + template + RegularPartners(const Decomposer_& decomposer, int k, bool contiguous = true): + divisions_(decomposer.divisions), + contiguous_(contiguous) { factor(k, divisions_, kvs_); fill_steps(); } + RegularPartners(const DivisionVector& divs, + const KVSVector& kvs, + bool contiguous = true): + divisions_(divs), kvs_(kvs), + contiguous_(contiguous) { fill_steps(); } + + size_t rounds() const { return kvs_.size(); } + int size(int round) const { return kvs_[round].size; } + int dim(int round) const { return kvs_[round].dim; } + + int step(int round) const 
{ return steps_[round]; } + + const DivisionVector& divisions() const { return divisions_; } + const KVSVector& kvs() const { return kvs_; } + bool contiguous() const { return contiguous_; } + + static + inline void factor(int k, const DivisionVector& divisions, KVSVector& kvs); + + inline void fill(int round, int gid, std::vector& partners) const; + inline int group_position(int round, int c, int step) const; + + private: + inline void fill_steps(); + static + inline void factor(int k, int tot_b, std::vector& kvs); + + DivisionVector divisions_; + KVSVector kvs_; + bool contiguous_; + std::vector steps_; +}; + +} + +void +diy::RegularPartners:: +fill_steps() +{ + if (contiguous_) + { + std::vector cur_steps(divisions().size(), 1); + + for (size_t r = 0; r < rounds(); ++r) + { + steps_.push_back(cur_steps[kvs_[r].dim]); + cur_steps[kvs_[r].dim] *= kvs_[r].size; + } + } else + { + std::vector cur_steps(divisions().begin(), divisions().end()); + for (size_t r = 0; r < rounds(); ++r) + { + cur_steps[kvs_[r].dim] /= kvs_[r].size; + steps_.push_back(cur_steps[kvs_[r].dim]); + } + } +} + +void +diy::RegularPartners:: +fill(int round, int gid, std::vector& partners) const +{ + const DimK& kv = kvs_[round]; + partners.reserve(kv.size); + + int step = this->step(round); // gids jump by this much in the current round + + CoordVector coords; + Decomposer::gid_to_coords(gid, coords, divisions_); + int c = coords[kv.dim]; + int pos = group_position(round, c, step); + + int partner = c - pos * step; + coords[kv.dim] = partner; + int partner_gid = Decomposer::coords_to_gid(coords, divisions_); + partners.push_back(partner_gid); + + for (int k = 1; k < kv.size; ++k) + { + partner += step; + coords[kv.dim] = partner; + int partner_gid = Decomposer::coords_to_gid(coords, divisions_); + partners.push_back(partner_gid); + } +} + +// Tom's GetGrpPos +int +diy::RegularPartners:: +group_position(int round, int c, int step) const +{ + // the second term in the following expression does 
not simplify to + // (gid - start_b) / kv[r] + // because the division gid / (step * kv[r]) is integer and truncates + // this is exactly what we want + int g = c % step + c / (step * kvs_[round].size) * step; + int p = c / step % kvs_[round].size; + static_cast(g); // shut up the compiler + + // g: group number (output) + // p: position number within the group (output) + return p; +} + +void +diy::RegularPartners:: +factor(int k, const DivisionVector& divisions, KVSVector& kvs) +{ + // factor in each dimension + std::vector< std::vector > tmp_kvs(divisions.size()); + for (unsigned i = 0; i < divisions.size(); ++i) + factor(k, divisions[i], tmp_kvs[i]); + + // interleave the dimensions + std::vector round_per_dim(divisions.size(), 0); + while(true) + { + // TODO: not the most efficient way to do this + bool changed = false; + for (unsigned i = 0; i < divisions.size(); ++i) + { + if (round_per_dim[i] == (int) tmp_kvs[i].size()) + continue; + kvs.push_back(DimK(i, tmp_kvs[i][round_per_dim[i]++])); + changed = true; + } + if (!changed) + break; + } +} + +// Tom's FactorK +void +diy::RegularPartners:: +factor(int k, int tot_b, std::vector& kv) +{ + int rem = tot_b; // unfactored remaining portion of tot_b + int j; + + while (rem > 1) + { + // remainder is divisible by k + if (rem % k == 0) + { + kv.push_back(k); + rem /= k; + } + // if not, start at k and linearly look for smaller factors down to 2 + else + { + for (j = k - 1; j > 1; j--) + { + if (rem % j == 0) + { + kv.push_back(j); + rem /= k; + break; + } + } + if (j == 1) + { + kv.push_back(rem); + rem = 1; + } + } // else + } // while +} + + +#endif diff --git a/diy/include/diy/partners/merge.hpp b/diy/include/diy/partners/merge.hpp new file mode 100644 index 000000000..c6be42533 --- /dev/null +++ b/diy/include/diy/partners/merge.hpp @@ -0,0 +1,60 @@ +#ifndef DIY_PARTNERS_MERGE_HPP +#define DIY_PARTNERS_MERGE_HPP + +#include "common.hpp" + +namespace diy +{ + +class Master; + +/** + * \ingroup Communication + * 
\brief Partners for merge-reduce + * + */ +struct RegularMergePartners: public RegularPartners +{ + typedef RegularPartners Parent; + + // contiguous parameter indicates whether to match partners contiguously or in a round-robin fashion; + // contiguous is useful when data needs to be united; + // round-robin is useful for vector-"halving" + template + RegularMergePartners(const Decomposer& decomposer, //!< domain decomposition + int k, //!< target k value + bool contiguous = true //!< distance doubling (true) or halving (false) + ): + Parent(decomposer, k, contiguous) {} + RegularMergePartners(const DivisionVector& divs, //!< explicit division vector + const KVSVector& kvs, //!< explicit k vector + bool contiguous = true //!< distance doubling (true) or halving (false) + ): + Parent(divs, kvs, contiguous) {} + + inline bool active(int round, int gid, const Master&) const; + + // incoming is only valid for an active gid; it will only be called with an active gid + inline void incoming(int round, int gid, std::vector& partners, const Master&) const { Parent::fill(round - 1, gid, partners); } + // this is a lazy implementation of outgoing, but it reuses the existing code + inline void outgoing(int round, int gid, std::vector& partners, const Master&) const { std::vector tmp; Parent::fill(round, gid, tmp); partners.push_back(tmp[0]); } +}; + +} // diy + +bool +diy::RegularMergePartners:: +active(int round, int gid, const Master&) const +{ + CoordVector coords; + Decomposer::gid_to_coords(gid, coords, divisions()); + + for (int r = 0; r < round; ++r) + if (Parent::group_position(r, coords[kvs()[r].dim], step(r)) != 0) + return false; + + return true; +} + +#endif + diff --git a/diy/include/diy/partners/swap.hpp b/diy/include/diy/partners/swap.hpp new file mode 100644 index 000000000..cc3b3e494 --- /dev/null +++ b/diy/include/diy/partners/swap.hpp @@ -0,0 +1,43 @@ +#ifndef DIY_PARTNERS_SWAP_HPP +#define DIY_PARTNERS_SWAP_HPP + +#include "common.hpp" + +namespace diy +{ 
+ +class Master; + +/** + * \ingroup Communication + * \brief Partners for swap-reduce + * + */ +struct RegularSwapPartners: public RegularPartners +{ + typedef RegularPartners Parent; + + // contiguous parameter indicates whether to match partners contiguously or in a round-robin fashion; + // contiguous is useful when data needs to be united; + // round-robin is useful for vector-"halving" + template + RegularSwapPartners(const Decomposer& decomposer, //!< domain decomposition + int k, //!< target k value + bool contiguous = true //!< distance halving (true) or doubling (false) + ): + Parent(decomposer, k, contiguous) {} + RegularSwapPartners(const DivisionVector& divs, //!< explicit division vector + const KVSVector& kvs, //!< explicit k vector + bool contiguous = true //!< distance halving (true) or doubling (false) + ): + Parent(divs, kvs, contiguous) {} + + bool active(int round, int gid, const Master&) const { return true; } // in swap-reduce every block is always active + + void incoming(int round, int gid, std::vector& partners, const Master&) const { Parent::fill(round - 1, gid, partners); } + void outgoing(int round, int gid, std::vector& partners, const Master&) const { Parent::fill(round, gid, partners); } +}; + +} // diy + +#endif diff --git a/diy/include/diy/pick.hpp b/diy/include/diy/pick.hpp new file mode 100644 index 000000000..5f9d8d0e8 --- /dev/null +++ b/diy/include/diy/pick.hpp @@ -0,0 +1,137 @@ +#ifndef DIY_PICK_HPP +#define DIY_PICK_HPP + +#include "link.hpp" + +namespace diy +{ + template + void near(const RegularLink& link, const Point& p, float r, OutIter out, + const Bounds& domain); + + template + void in(const RegularLink& link, const Point& p, OutIter out, const Bounds& domain); + + template + float distance(int dim, const Bounds& bounds, const Point& p); + + template + inline + float distance(int dim, const Bounds& bounds1, const Bounds& bounds2); + + template + void wrap_bounds(Bounds& bounds, Direction wrap_dir, const Bounds& 
domain, int dim); +} + +//! Finds the neighbors within radius r of a target point. +template +void +diy:: +near(const RegularLink& link, //!< neighbors + const Point& p, //!< target point (must be in current block) + float r, //!< target radius (>= 0.0) + OutIter out, //!< insert iterator for output set of neighbors + const Bounds& domain) //!< global domain bounds +{ + Bounds neigh_bounds; // neighbor block bounds + + // for all neighbors of this block + for (int n = 0; n < link.size(); n++) + { + // wrap neighbor bounds, if necessary, otherwise bounds will be unchanged + neigh_bounds = link.bounds(n); + wrap_bounds(neigh_bounds, link.wrap(n), domain, link.dimension()); + + if (distance(link.dimension(), neigh_bounds, p) <= r) + *out++ = n; + } // for all neighbors +} + +//! Find the distance between point `p` and box `bounds`. +template +float +diy:: +distance(int dim, const Bounds& bounds, const Point& p) +{ + float res = 0; + for (int i = 0; i < dim; ++i) + { + // avoids all the annoying case logic by finding + // diff = max(bounds.min[i] - p[i], 0, p[i] - bounds.max[i]) + float diff = 0, d; + + d = bounds.min[i] - p[i]; + if (d > diff) diff = d; + d = p[i] - bounds.max[i]; + if (d > diff) diff = d; + + res += diff*diff; + } + return sqrt(res); +} + +template +float +diy:: +distance(int dim, const Bounds& bounds1, const Bounds& bounds2) +{ + float res = 0; + for (int i = 0; i < dim; ++i) + { + float diff = 0, d; + + float d1 = bounds1.max[i] - bounds2.min[i]; + float d2 = bounds2.max[i] - bounds1.min[i]; + + if (d1 > 0 && d2 > 0) + diff = 0; + else if (d1 <= 0) + diff = -d1; + else if (d2 <= 0) + diff = -d2; + + res += diff*diff; + } + return sqrt(res); +} + +//! Finds the neighbor(s) containing the target point. 
+template +void +diy:: +in(const RegularLink& link, //!< neighbors + const Point& p, //!< target point + OutIter out, //!< insert iterator for output set of neighbors + const Bounds& domain) //!< global domain bounds +{ + Bounds neigh_bounds; // neighbor block bounds + + // for all neighbors of this block + for (int n = 0; n < link.size(); n++) + { + // wrap neighbor bounds, if necessary, otherwise bounds will be unchanged + neigh_bounds = link.bounds(n); + wrap_bounds(neigh_bounds, link.wrap(n), domain, link.dimension()); + + if (distance(link.dimension(), neigh_bounds, p) == 0) + *out++ = n; + } // for all neighbors +} + +// wraps block bounds +// wrap dir is the wrapping direction from original block to wrapped neighbor block +// overall domain bounds and dimensionality are also needed +template +void +diy:: +wrap_bounds(Bounds& bounds, Direction wrap_dir, const Bounds& domain, int dim) +{ + for (int i = 0; i < dim; ++i) + { + bounds.min[i] += wrap_dir[i] * (domain.max[i] - domain.min[i]); + bounds.max[i] += wrap_dir[i] * (domain.max[i] - domain.min[i]); + } +} + + +#endif diff --git a/diy/include/diy/point.hpp b/diy/include/diy/point.hpp new file mode 100644 index 000000000..cafbe784c --- /dev/null +++ b/diy/include/diy/point.hpp @@ -0,0 +1,120 @@ +#ifndef DIY_POINT_HPP +#define DIY_POINT_HPP + +#include +#include +#include +#include + +#include + +namespace diy +{ + +template +class Point: public std::array +{ + public: + typedef Coordinate_ Coordinate; + typedef std::array ArrayParent; + + typedef Point LPoint; + typedef Point UPoint; + + template + struct rebind { typedef Point type; }; + + public: + Point() { for (unsigned i = 0; i < D; ++i) (*this)[i] = 0; } + Point(const ArrayParent& a): + ArrayParent(a) {} + template Point(const Point& p) { for (size_t i = 0; i < D; ++i) (*this)[i] = p[i]; } + template Point(const T* a) { for (unsigned i = 0; i < D; ++i) (*this)[i] = a[i]; } + template Point(const std::vector& a) { for (unsigned i = 0; i < D; ++i) 
(*this)[i] = a[i]; } + Point(std::initializer_list lst) { unsigned i = 0; for (Coordinate x : lst) (*this)[i++] = x; } + + Point(Point&&) =default; + Point(const Point&) =default; + Point& operator=(const Point&) =default; + + static constexpr + unsigned dimension() { return D; } + + static Point zero() { return Point(); } + static Point one() { Point p; for (unsigned i = 0; i < D; ++i) p[i] = 1; return p; } + + LPoint drop(int dim) const { LPoint p; unsigned c = 0; for (unsigned i = 0; i < D; ++i) { if (i == dim) continue; p[c++] = (*this)[i]; } return p; } + UPoint lift(int dim, Coordinate x) const { UPoint p; for (unsigned i = 0; i < D+1; ++i) { if (i < dim) p[i] = (*this)[i]; else if (i == dim) p[i] = x; else if (i > dim) p[i] = (*this)[i-1]; } return p; } + + using ArrayParent::operator[]; + + Point& operator+=(const Point& y) { for (unsigned i = 0; i < D; ++i) (*this)[i] += y[i]; return *this; } + Point& operator-=(const Point& y) { for (unsigned i = 0; i < D; ++i) (*this)[i] -= y[i]; return *this; } + Point& operator*=(Coordinate a) { for (unsigned i = 0; i < D; ++i) (*this)[i] *= a; return *this; } + Point& operator/=(Coordinate a) { for (unsigned i = 0; i < D; ++i) (*this)[i] /= a; return *this; } + + Coordinate norm() const { return (*this)*(*this); } + + std::ostream& operator<<(std::ostream& out) const { out << (*this)[0]; for (unsigned i = 1; i < D; ++i) out << " " << (*this)[i]; return out; } + std::istream& operator>>(std::istream& in); + + friend + Point operator+(Point x, const Point& y) { x += y; return x; } + + friend + Point operator-(Point x, const Point& y) { x -= y; return x; } + + friend + Point operator/(Point x, Coordinate y) { x /= y; return x; } + + friend + Point operator*(Point x, Coordinate y) { x *= y; return x; } + + friend + Point operator*(Coordinate y, Point x) { x *= y; return x; } + + friend + Coordinate operator*(const Point& x, const Point& y) { Coordinate n = 0; for (size_t i = 0; i < D; ++i) n += x[i] * y[i]; return n; } + 
+ template + friend + Coordinate operator*(const Point& x, const Point& y) { Coordinate n = 0; for (size_t i = 0; i < D; ++i) n += x[i] * y[i]; return n; } +}; + +template +std::istream& +Point:: +operator>>(std::istream& in) +{ + std::string point_str; + in >> point_str; // read until ' ' + std::stringstream ps(point_str); + + char x; + for (unsigned i = 0; i < dimension(); ++i) + { + ps >> (*this)[i]; + ps >> x; + } + + return in; +} + + +template +Coordinate norm2(const Point& p) +{ Coordinate res = 0; for (unsigned i = 0; i < D; ++i) res += p[i]*p[i]; return res; } + +template +std::ostream& +operator<<(std::ostream& out, const Point& p) +{ return p.operator<<(out); } + +template +std::istream& +operator>>(std::istream& in, Point& p) +{ return p.operator>>(in); } + +} + +#endif // DIY_POINT_HPP diff --git a/diy/include/diy/proxy.hpp b/diy/include/diy/proxy.hpp new file mode 100644 index 000000000..0160e0605 --- /dev/null +++ b/diy/include/diy/proxy.hpp @@ -0,0 +1,228 @@ +#ifndef DIY_PROXY_HPP +#define DIY_PROXY_HPP + + +namespace diy +{ + //! Communication proxy, used for enqueueing and dequeueing items for future exchange. + struct Master::Proxy + { + template + struct EnqueueIterator; + + Proxy(Master* master, int gid): + gid_(gid), + master_(master), + incoming_(&master->incoming(gid)), + outgoing_(&master->outgoing(gid)), + collectives_(&master->collectives(gid)) {} + + int gid() const { return gid_; } + + //! Enqueue data whose size can be determined automatically, e.g., an STL vector. + template + void enqueue(const BlockID& to, //!< target block (gid,proc) + const T& x, //!< data (eg. STL vector) + void (*save)(BinaryBuffer&, const T&) = &::diy::save //!< optional serialization function + ) const + { OutgoingQueues& out = *outgoing_; save(out[to], x); } + + //! Enqueue data whose size is given explicitly by the user, e.g., an array. + template + void enqueue(const BlockID& to, //!< target block (gid,proc) + const T* x, //!< pointer to the data (eg. 
address of start of vector) + size_t n, //!< size in data elements (eg. ints) + void (*save)(BinaryBuffer&, const T&) = &::diy::save //!< optional serialization function + ) const; + + //! Dequeue data whose size can be determined automatically (e.g., STL vector) and that was + //! previously enqueued so that diy knows its size when it is received. + //! In this case, diy will allocate the receive buffer; the user does not need to do so. + template + void dequeue(int from, //!< target block gid + T& x, //!< data (eg. STL vector) + void (*load)(BinaryBuffer&, T&) = &::diy::load //!< optional serialization function + ) const + { IncomingQueues& in = *incoming_; load(in[from], x); } + + //! Dequeue an array of data whose size is given explicitly by the user. + //! In this case, the user needs to allocate the receive buffer prior to calling dequeue. + template + void dequeue(int from, //!< target block gid + T* x, //!< pointer to the data (eg. address of start of vector) + size_t n, //!< size in data elements (eg. ints) + void (*load)(BinaryBuffer&, T&) = &::diy::load //!< optional serialization function + ) const; + + template + EnqueueIterator enqueuer(const T& x, + void (*save)(BinaryBuffer&, const T&) = &::diy::save) const + { return EnqueueIterator(this, x, save); } + + IncomingQueues* incoming() const { return incoming_; } + MemoryBuffer& incoming(int from) const { return (*incoming_)[from]; } + inline void incoming(std::vector& v) const; // fill v with every gid from which we have a message + + OutgoingQueues* outgoing() const { return outgoing_; } + MemoryBuffer& outgoing(const BlockID& to) const { return (*outgoing_)[to]; } + +/** + * \ingroup Communication + * \brief Post an all-reduce collective using an existing communication proxy. + * Available operators are: + * maximum, minimum, std::plus, std::multiplies, std::logical_and, and + * std::logical_or. 
+ */ + template + inline void all_reduce(const T& in, //!< local value being reduced + Op op //!< operator + ) const; +/** + * \ingroup Communication + * \brief Return the result of a proxy collective without popping it off the collectives list (same result would be returned multiple times). The list can be cleared with collectives()->clear(). + */ + template + inline T read() const; +/** + * \ingroup Communication + * \brief Return the result of a proxy collective; result is popped off the collectives list. + */ + template + inline T get() const; + + template + inline void scratch(const T& in) const; + +/** + * \ingroup Communication + * \brief Return the list of proxy collectives (values and operations) + */ + CollectivesList* collectives() const { return collectives_; } + + Master* master() const { return master_; } + + private: + int gid_; + Master* master_; + IncomingQueues* incoming_; + OutgoingQueues* outgoing_; + CollectivesList* collectives_; + }; + + template + struct Master::Proxy::EnqueueIterator: + public std::iterator + { + typedef void (*SaveT)(BinaryBuffer&, const T&); + + EnqueueIterator(const Proxy* proxy, const T& x, + SaveT save = &::diy::save): + proxy_(proxy), x_(x), save_(save) {} + + EnqueueIterator& operator=(const BlockID& to) { proxy_->enqueue(to, x_, save_); return *this; } + EnqueueIterator& operator*() { return *this; } + EnqueueIterator& operator++() { return *this; } + EnqueueIterator& operator++(int) { return *this; } + + private: + const Proxy* proxy_; + const T& x_; + SaveT save_; + + }; + + struct Master::ProxyWithLink: public Master::Proxy + { + ProxyWithLink(const Proxy& proxy, + void* block, + Link* link): + Proxy(proxy), + block_(block), + link_(link) {} + + Link* link() const { return link_; } + void* block() const { return block_; } + + private: + void* block_; + Link* link_; + }; +} + + +void +diy::Master::Proxy:: +incoming(std::vector& v) const +{ + for (IncomingQueues::const_iterator it = incoming_->begin(); it != 
incoming_->end(); ++it) + v.push_back(it->first); +} + +template +void +diy::Master::Proxy:: +all_reduce(const T& in, Op op) const +{ + collectives_->push_back(Collective(new detail::AllReduceOp(in, op))); +} + +template +T +diy::Master::Proxy:: +read() const +{ + T res; + collectives_->front().result_out(&res); + return res; +} + +template +T +diy::Master::Proxy:: +get() const +{ + T res = read(); + collectives_->pop_front(); + return res; +} + +template +void +diy::Master::Proxy:: +scratch(const T& in) const +{ + collectives_->push_back(Collective(new detail::Scratch(in))); +} + +template +void +diy::Master::Proxy:: +enqueue(const BlockID& to, const T* x, size_t n, + void (*save)(BinaryBuffer&, const T&)) const +{ + OutgoingQueues& out = *outgoing_; + BinaryBuffer& bb = out[to]; + if (save == (void (*)(BinaryBuffer&, const T&)) &::diy::save) + diy::save(bb, x, n); // optimized for unspecialized types + else + for (size_t i = 0; i < n; ++i) + save(bb, x[i]); +} + +template +void +diy::Master::Proxy:: +dequeue(int from, T* x, size_t n, + void (*load)(BinaryBuffer&, T&)) const +{ + IncomingQueues& in = *incoming_; + BinaryBuffer& bb = in[from]; + if (load == (void (*)(BinaryBuffer&, T&)) &::diy::load) + diy::load(bb, x, n); // optimized for unspecialized types + else + for (size_t i = 0; i < n; ++i) + load(bb, x[i]); +} + + +#endif diff --git a/diy/include/diy/reduce-operations.hpp b/diy/include/diy/reduce-operations.hpp new file mode 100644 index 000000000..629824da5 --- /dev/null +++ b/diy/include/diy/reduce-operations.hpp @@ -0,0 +1,32 @@ +#ifndef DIY_REDUCE_OPERATIONS_HPP +#define DIY_REDUCE_OPERATIONS_HPP + +#include "reduce.hpp" +#include "partners/swap.hpp" +#include "detail/reduce/all-to-all.hpp" + +namespace diy +{ + +/** + * \ingroup Communication + * \brief all to all reduction + * + */ +template +void +all_to_all(Master& master, //!< block owner + const Assigner& assigner, //!< global block locator (maps gid to proc) + const Op& op, //!< user-defined 
operation called to enqueue and dequeue items + int k = 2 //!< reduction fanout + ) +{ + auto scoped = master.prof.scoped("all_to_all"); + RegularDecomposer decomposer(1, interval(0,assigner.nblocks()-1), assigner.nblocks()); + RegularSwapPartners partners(decomposer, k, false); + reduce(master, assigner, partners, detail::AllToAllReduce(op, assigner), detail::SkipIntermediate(partners.rounds())); +} + +} + +#endif diff --git a/diy/include/diy/reduce.hpp b/diy/include/diy/reduce.hpp new file mode 100644 index 000000000..6d47d7930 --- /dev/null +++ b/diy/include/diy/reduce.hpp @@ -0,0 +1,216 @@ +#ifndef DIY_REDUCE_HPP +#define DIY_REDUCE_HPP + +#include +#include "master.hpp" +#include "assigner.hpp" +#include "detail/block_traits.hpp" +#include "log.hpp" + +namespace diy +{ +//! Enables communication within a group during a reduction. +//! DIY creates the ReduceProxy for you in diy::reduce() +//! and provides a reference to ReduceProxy each time the user's reduction function is called +struct ReduceProxy: public Master::Proxy +{ + typedef std::vector GIDVector; + + ReduceProxy(const Master::Proxy& proxy, //!< parent proxy + void* block, //!< diy block + unsigned round, //!< current round + const Assigner& assigner, //!< assigner + const GIDVector& incoming_gids, //!< incoming gids in this group + const GIDVector& outgoing_gids): //!< outgoing gids in this group + Master::Proxy(proxy), + block_(block), + round_(round), + assigner_(assigner) + { + // setup in_link + for (unsigned i = 0; i < incoming_gids.size(); ++i) + { + BlockID nbr; + nbr.gid = incoming_gids[i]; + nbr.proc = assigner.rank(nbr.gid); + in_link_.add_neighbor(nbr); + } + + // setup out_link + for (unsigned i = 0; i < outgoing_gids.size(); ++i) + { + BlockID nbr; + nbr.gid = outgoing_gids[i]; + nbr.proc = assigner.rank(nbr.gid); + out_link_.add_neighbor(nbr); + } + } + + ReduceProxy(const Master::Proxy& proxy, //!< parent proxy + void* block, //!< diy block + unsigned round, //!< current round + const 
Assigner& assigner, + const Link& in_link, + const Link& out_link): + Master::Proxy(proxy), + block_(block), + round_(round), + assigner_(assigner), + in_link_(in_link), + out_link_(out_link) + {} + + //! returns pointer to block + void* block() const { return block_; } + //! returns current round number + unsigned round() const { return round_; } + //! returns incoming link + const Link& in_link() const { return in_link_; } + //! returns outgoing link + const Link& out_link() const { return out_link_; } + //! returns total number of blocks + int nblocks() const { return assigner_.nblocks(); } + //! returns the assigner + const Assigner& assigner() const { return assigner_; } + + //! advanced: change current round number + void set_round(unsigned r) { round_ = r; } + + private: + void* block_; + unsigned round_; + const Assigner& assigner_; + + Link in_link_; + Link out_link_; +}; + +namespace detail +{ + template + struct ReductionFunctor; + + template + struct SkipInactiveOr; + + struct ReduceNeverSkip + { + bool operator()(int round, int lid, const Master& master) const { return false; } + }; +} + +/** + * \ingroup Communication + * \brief Implementation of the reduce communication pattern (includes + * swap-reduce, merge-reduce, and any other global communication). 
+ * + */ +template +void reduce(Master& master, //!< master object + const Assigner& assigner, //!< assigner object + const Partners& partners, //!< partners object + const Reduce& reduce, //!< reduction callback function + const Skip& skip) //!< object determining whether a block should be skipped +{ + auto log = get_logger(); + + int original_expected = master.expected(); + + using Block = typename detail::block_traits::type; + + unsigned round; + for (round = 0; round < partners.rounds(); ++round) + { + log->debug("Round {}", round); + master.foreach(detail::ReductionFunctor(round, reduce, partners, assigner), + detail::SkipInactiveOr(round, partners, skip)); + master.execute(); + + int expected = 0; + for (unsigned i = 0; i < master.size(); ++i) + { + if (partners.active(round + 1, master.gid(i), master)) + { + std::vector incoming_gids; + partners.incoming(round + 1, master.gid(i), incoming_gids, master); + expected += incoming_gids.size(); + master.incoming(master.gid(i)).clear(); + } + } + master.set_expected(expected); + master.flush(); + } + // final round + log->debug("Round {}", round); + master.foreach(detail::ReductionFunctor(round, reduce, partners, assigner), + detail::SkipInactiveOr(round, partners, skip)); + + master.set_expected(original_expected); +} + +/** + * \ingroup Communication + * \brief Implementation of the reduce communication pattern (includes + * swap-reduce, merge-reduce, and any other global communication). 
+ * + */ +template +void reduce(Master& master, //!< master object + const Assigner& assigner, //!< assigner object + const Partners& partners, //!< partners object + const Reduce& reducer) //!< reduction callback function +{ + reduce(master, assigner, partners, reducer, detail::ReduceNeverSkip()); +} + +namespace detail +{ + template + struct ReductionFunctor + { + using Callback = std::function; + + ReductionFunctor(unsigned round_, const Callback& reduce_, const Partners& partners_, const Assigner& assigner_): + round(round_), reduce(reduce_), partners(partners_), assigner(assigner_) {} + + void operator()(Block* b, const Master::ProxyWithLink& cp) const + { + if (!partners.active(round, cp.gid(), *cp.master())) return; + + std::vector incoming_gids, outgoing_gids; + if (round > 0) + partners.incoming(round, cp.gid(), incoming_gids, *cp.master()); // receive from the previous round + if (round < partners.rounds()) + partners.outgoing(round, cp.gid(), outgoing_gids, *cp.master()); // send to the next round + + ReduceProxy rp(cp, b, round, assigner, incoming_gids, outgoing_gids); + reduce(b, rp, partners); + + // touch the outgoing queues to make sure they exist + Master::OutgoingQueues& outgoing = *cp.outgoing(); + if (outgoing.size() < (size_t) rp.out_link().size()) + for (int j = 0; j < rp.out_link().size(); ++j) + outgoing[rp.out_link().target(j)]; // touch the outgoing queue, creating it if necessary + } + + unsigned round; + Callback reduce; + Partners partners; + const Assigner& assigner; + }; + + template + struct SkipInactiveOr + { + SkipInactiveOr(int round_, const Partners& partners_, const Skip& skip_): + round(round_), partners(partners_), skip(skip_) {} + bool operator()(int i, const Master& master) const { return !partners.active(round, master.gid(i), master) || skip(round, i, master); } + int round; + const Partners& partners; + Skip skip; + }; +} + +} // diy + +#endif // DIY_REDUCE_HPP diff --git a/diy/include/diy/serialization.hpp 
b/diy/include/diy/serialization.hpp new file mode 100644 index 000000000..25640255d --- /dev/null +++ b/diy/include/diy/serialization.hpp @@ -0,0 +1,456 @@ +#ifndef DIY_SERIALIZATION_HPP +#define DIY_SERIALIZATION_HPP + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include // this is used for a safety check for default serialization + +namespace diy +{ + //! A serialization buffer. \ingroup Serialization + struct BinaryBuffer + { + virtual void save_binary(const char* x, size_t count) =0; //!< copy `count` bytes from `x` into the buffer + virtual void load_binary(char* x, size_t count) =0; //!< copy `count` bytes into `x` from the buffer + virtual void load_binary_back(char* x, size_t count) =0; //!< copy `count` bytes into `x` from the back of the buffer + }; + + struct MemoryBuffer: public BinaryBuffer + { + MemoryBuffer(size_t position_ = 0): + position(position_) {} + + virtual inline void save_binary(const char* x, size_t count) override; //!< copy `count` bytes from `x` into the buffer + virtual inline void load_binary(char* x, size_t count) override; //!< copy `count` bytes into `x` from the buffer + virtual inline void load_binary_back(char* x, size_t count) override; //!< copy `count` bytes into `x` from the back of the buffer + + void clear() { buffer.clear(); reset(); } + void wipe() { std::vector().swap(buffer); reset(); } + void reset() { position = 0; } + void skip(size_t s) { position += s; } + void swap(MemoryBuffer& o) { std::swap(position, o.position); buffer.swap(o.buffer); } + bool empty() const { return buffer.empty(); } + size_t size() const { return buffer.size(); } + void reserve(size_t s) { buffer.reserve(s); } + operator bool() const { return position < buffer.size(); } + + //! copy a memory buffer from one buffer to another, bypassing making a temporary copy first + inline static void copy(MemoryBuffer& from, MemoryBuffer& to); + + //! 
multiplier used for the geometric growth of the container + static float growth_multiplier() { return 1.5; } + + // simple file IO + void write(const std::string& fn) const { std::ofstream out(fn.c_str()); out.write(&buffer[0], size()); } + void read(const std::string& fn) + { + std::ifstream in(fn.c_str(), std::ios::binary | std::ios::ate); + buffer.resize(in.tellg()); + in.seekg(0); + in.read(&buffer[0], size()); + position = 0; + } + + size_t position; + std::vector buffer; + }; + + namespace detail + { + struct Default {}; + } + + //!\addtogroup Serialization + //!@{ + + /** + * \brief Main interface to serialization, meant to be specialized for the + * types that require special handling. `diy::save()` and `diy::load()` call + * the static member functions of this class. + * + * The default (unspecialized) version copies + * `sizeof(T)` bytes from `&x` to or from `bb` via + * its `diy::BinaryBuffer::save_binary()` and `diy::BinaryBuffer::load_binary()` + * functions. This works out perfectly for plain old data (e.g., simple structs). + * To save a more complicated type, one has to specialize + * `diy::Serialization` for that type. Specializations are already provided for + * `std::vector`, `std::map`, and `std::pair`. + * As a result one can quickly add a specialization of one's own + * + */ + template + struct Serialization: public detail::Default + { +#if defined(__clang__) || (defined(__GNUC__) && __GNUC__ >= 5) + static_assert(std::is_trivially_copyable::value, "Default serialization works only for trivially copyable types"); +#endif + + static void save(BinaryBuffer& bb, const T& x) { bb.save_binary((const char*) &x, sizeof(T)); } + static void load(BinaryBuffer& bb, T& x) { bb.load_binary((char*) &x, sizeof(T)); } + }; + + //! Saves `x` to `bb` by calling `diy::Serialization::save(bb,x)`. + template + void save(BinaryBuffer& bb, const T& x) { Serialization::save(bb, x); } + + //! Loads `x` from `bb` by calling `diy::Serialization::load(bb,x)`. 
+ template + void load(BinaryBuffer& bb, T& x) { Serialization::load(bb, x); } + + //! Optimization for arrays. If `diy::Serialization` is not specialized for `T`, + //! the array will be copied all at once. Otherwise, it's copied element by element. + template + void save(BinaryBuffer& bb, const T* x, size_t n); + + //! Optimization for arrays. If `diy::Serialization` is not specialized for `T`, + //! the array will be filled all at once. Otherwise, it's filled element by element. + template + void load(BinaryBuffer& bb, T* x, size_t n); + + //! Supports only binary data copying (meant for simple footers). + template + void load_back(BinaryBuffer& bb, T& x) { bb.load_binary_back((char*) &x, sizeof(T)); } + + //@} + + + namespace detail + { + template + struct is_default + { + typedef char yes; + typedef int no; + + static yes test(Default*); + static no test(...); + + enum { value = (sizeof(test((T*) 0)) == sizeof(yes)) }; + }; + } + + template + void save(BinaryBuffer& bb, const T* x, size_t n) + { + if (!detail::is_default< Serialization >::value) + for (size_t i = 0; i < n; ++i) + diy::save(bb, x[i]); + else // if Serialization is not specialized for U, just save the binary data + bb.save_binary((const char*) &x[0], sizeof(T)*n); + } + + template + void load(BinaryBuffer& bb, T* x, size_t n) + { + if (!detail::is_default< Serialization >::value) + for (size_t i = 0; i < n; ++i) + diy::load(bb, x[i]); + else // if Serialization is not specialized for U, just load the binary data + bb.load_binary((char*) &x[0], sizeof(T)*n); + } + + + // save/load for MemoryBuffer + template<> + struct Serialization< MemoryBuffer > + { + static void save(BinaryBuffer& bb, const MemoryBuffer& x) + { + diy::save(bb, x.position); + diy::save(bb, &x.buffer[0], x.position); + } + + static void load(BinaryBuffer& bb, MemoryBuffer& x) + { + diy::load(bb, x.position); + x.buffer.resize(x.position); + diy::load(bb, &x.buffer[0], x.position); + } + }; + + // save/load for std::vector + 
template + struct Serialization< std::vector > + { + typedef std::vector Vector; + + static void save(BinaryBuffer& bb, const Vector& v) + { + size_t s = v.size(); + diy::save(bb, s); + diy::save(bb, &v[0], v.size()); + } + + static void load(BinaryBuffer& bb, Vector& v) + { + size_t s; + diy::load(bb, s); + v.resize(s); + diy::load(bb, &v[0], s); + } + }; + + template + struct Serialization< std::valarray > + { + typedef std::valarray ValArray; + + static void save(BinaryBuffer& bb, const ValArray& v) + { + size_t s = v.size(); + diy::save(bb, s); + diy::save(bb, &v[0], v.size()); + } + + static void load(BinaryBuffer& bb, ValArray& v) + { + size_t s; + diy::load(bb, s); + v.resize(s); + diy::load(bb, &v[0], s); + } + }; + + // save/load for std::string + template<> + struct Serialization< std::string > + { + typedef std::string String; + + static void save(BinaryBuffer& bb, const String& s) + { + size_t sz = s.size(); + diy::save(bb, sz); + diy::save(bb, s.c_str(), sz); + } + + static void load(BinaryBuffer& bb, String& s) + { + size_t sz; + diy::load(bb, sz); + s.resize(sz); + for (size_t i = 0; i < sz; ++i) + { + char c; + diy::load(bb, c); + s[i] = c; + } + } + }; + + // save/load for std::pair + template + struct Serialization< std::pair > + { + typedef std::pair Pair; + + static void save(BinaryBuffer& bb, const Pair& p) + { + diy::save(bb, p.first); + diy::save(bb, p.second); + } + + static void load(BinaryBuffer& bb, Pair& p) + { + diy::load(bb, p.first); + diy::load(bb, p.second); + } + }; + + // save/load for std::map + template + struct Serialization< std::map > + { + typedef std::map Map; + + static void save(BinaryBuffer& bb, const Map& m) + { + size_t s = m.size(); + diy::save(bb, s); + for (typename std::map::const_iterator it = m.begin(); it != m.end(); ++it) + diy::save(bb, *it); + } + + static void load(BinaryBuffer& bb, Map& m) + { + size_t s; + diy::load(bb, s); + for (size_t i = 0; i < s; ++i) + { + K k; + diy::load(bb, k); + diy::load(bb, 
m[k]); + } + } + }; + + // save/load for std::set + template + struct Serialization< std::set > + { + typedef std::set Set; + + static void save(BinaryBuffer& bb, const Set& m) + { + size_t s = m.size(); + diy::save(bb, s); + for (typename std::set::const_iterator it = m.begin(); it != m.end(); ++it) + diy::save(bb, *it); + } + + static void load(BinaryBuffer& bb, Set& m) + { + size_t s; + diy::load(bb, s); + for (size_t i = 0; i < s; ++i) + { + T p; + diy::load(bb, p); + m.insert(p); + } + } + }; + + // save/load for std::unordered_map + template + struct Serialization< std::unordered_map > + { + typedef std::unordered_map Map; + + static void save(BinaryBuffer& bb, const Map& m) + { + size_t s = m.size(); + diy::save(bb, s); + for (auto& x : m) + diy::save(bb, x); + } + + static void load(BinaryBuffer& bb, Map& m) + { + size_t s; + diy::load(bb, s); + for (size_t i = 0; i < s; ++i) + { + std::pair p; + diy::load(bb, p); + m.emplace(std::move(p)); + } + } + }; + + // save/load for std::unordered_set + template + struct Serialization< std::unordered_set > + { + typedef std::unordered_set Set; + + static void save(BinaryBuffer& bb, const Set& m) + { + size_t s = m.size(); + diy::save(bb, s); + for (auto& x : m) + diy::save(bb, x); + } + + static void load(BinaryBuffer& bb, Set& m) + { + size_t s; + diy::load(bb, s); + for (size_t i = 0; i < s; ++i) + { + T p; + diy::load(bb, p); + m.emplace(std::move(p)); + } + } + }; + + // save/load for std::tuple<...> + // TODO: this ought to be default (copying) serialization + // if all arguments are default + template + struct Serialization< std::tuple > + { + typedef std::tuple Tuple; + + static void save(BinaryBuffer& bb, const Tuple& t) { save<0>(bb, t); } + + template + static + typename std::enable_if::type + save(BinaryBuffer&, const Tuple&) {} + + template + static + typename std::enable_if::type + save(BinaryBuffer& bb, const Tuple& t) { diy::save(bb, std::get(t)); save(bb, t); } + + static void load(BinaryBuffer& bb, 
Tuple& t) { load<0>(bb, t); } + + template + static + typename std::enable_if::type + load(BinaryBuffer&, Tuple&) {} + + template + static + typename std::enable_if::type + load(BinaryBuffer& bb, Tuple& t) { diy::load(bb, std::get(t)); load(bb, t); } + + }; +} + +void +diy::MemoryBuffer:: +save_binary(const char* x, size_t count) +{ + if (position + count > buffer.capacity()) + buffer.reserve((position + count) * growth_multiplier()); // if we have to grow, grow geometrically + + if (position + count > buffer.size()) + buffer.resize(position + count); + + std::copy(x, x + count, &buffer[position]); + position += count; +} + +void +diy::MemoryBuffer:: +load_binary(char* x, size_t count) +{ + std::copy(&buffer[position], &buffer[position + count], x); + position += count; +} + +void +diy::MemoryBuffer:: +load_binary_back(char* x, size_t count) +{ + std::copy(&buffer[buffer.size() - count], &buffer[buffer.size()], x); + buffer.resize(buffer.size() - count); +} + +void +diy::MemoryBuffer:: +copy(MemoryBuffer& from, MemoryBuffer& to) +{ + size_t sz; + diy::load(from, sz); + from.position -= sizeof(size_t); + + size_t total = sizeof(size_t) + sz; + to.buffer.resize(to.position + total); + std::copy(&from.buffer[from.position], &from.buffer[from.position + total], &to.buffer[to.position]); + to.position += total; + from.position += total; +} + +#endif diff --git a/diy/include/diy/stats.hpp b/diy/include/diy/stats.hpp new file mode 100644 index 000000000..0628146df --- /dev/null +++ b/diy/include/diy/stats.hpp @@ -0,0 +1,120 @@ +#ifndef DIY_STATS_HPP +#define DIY_STATS_HPP + +#include +#include +#include + +#include "log.hpp" // need this for format +#define DIY_PROFILE 1 +namespace diy +{ +namespace stats +{ + +#if defined(DIY_PROFILE) +struct Profiler +{ + using Clock = std::chrono::high_resolution_clock; + using Time = Clock::time_point; + + struct Event + { + Event(const std::string& name_, bool begin_): + name(name_), + begin(begin_), + stamp(Clock::now()) + {} + + 
std::string name; + bool begin; + Time stamp; + }; + + using EventsVector = std::vector; + + struct Scoped + { + Scoped(Profiler& prof_, std::string name_): + prof(prof_), name(name_), active(true) { prof << name; } + ~Scoped() { if (active) prof >> name; } + + Scoped(Scoped&& other): + prof(other.prof), + name(other.name), + active(other.active) { other.active = false; } + + Scoped& + operator=(Scoped&& other) = delete; + Scoped(const Scoped&) = delete; + Scoped& + operator=(const Scoped&) = delete; + + Profiler& prof; + std::string name; + bool active; + }; + + Profiler() { reset_time(); } + + void reset_time() { start = Clock::now(); } + + void operator<<(std::string name) { enter(name); } + void operator>>(std::string name) { exit(name); } + + void enter(std::string name) { events.push_back(Event(name, true)); } + void exit(std::string name) { events.push_back(Event(name, false)); } + + void output(std::ostream& out) + { + for (size_t i = 0; i < events.size(); ++i) + { + const Event& e = events[i]; + auto time = std::chrono::duration_cast(e.stamp - start).count(); + fmt::print(out, "{} {} {}\n", + time / 1000000., + (e.begin ? '<' : '>'), + e.name); + /* + fmt::print(out, "{:02d}:{:02d}:{:02d}.{:06d} {}{}\n", + time/1000000/60/60, + time/1000000/60 % 60, + time/1000000 % 60, + time % 1000000, + (e.begin ? 
'<' : '>'), + e.name); + */ + } + } + + Scoped scoped(std::string name) { return Scoped(*this, name); } + + void clear() { events.clear(); } + + private: + Time start; + EventsVector events; +}; +#else +struct Profiler +{ + struct Scoped {}; + + void reset_time() {} + + void operator<<(std::string) {} + void operator>>(std::string) {} + + void enter(const std::string&) {} + void exit(const std::string&) {} + + void output(std::ostream&) {} + void clear() {} + + Scoped scoped(std::string) { return Scoped(); } +}; +#endif +} +} + +#endif diff --git a/diy/include/diy/storage.hpp b/diy/include/diy/storage.hpp new file mode 100644 index 000000000..62213b2c5 --- /dev/null +++ b/diy/include/diy/storage.hpp @@ -0,0 +1,228 @@ +#ifndef DIY_STORAGE_HPP +#define DIY_STORAGE_HPP + +#include +#include +#include + +#include // mkstemp() on Mac +#include // mkstemp() on Linux +#include // remove() +#include + +#include "serialization.hpp" +#include "thread.hpp" +#include "log.hpp" + +namespace diy +{ + namespace detail + { + typedef void (*Save)(const void*, BinaryBuffer& buf); + typedef void (*Load)(void*, BinaryBuffer& buf); + + struct FileBuffer: public BinaryBuffer + { + FileBuffer(FILE* file_): file(file_), head(0), tail(0) {} + + // TODO: add error checking + virtual inline void save_binary(const char* x, size_t count) override { fwrite(x, 1, count, file); head += count; } + virtual inline void load_binary(char* x, size_t count) override { fread(x, 1, count, file); } + virtual inline void load_binary_back(char* x, size_t count) override { fseek(file, tail, SEEK_END); fread(x, 1, count, file); tail += count; fseek(file, head, SEEK_SET); } + + size_t size() const { return head; } + + FILE* file; + size_t head, tail; // tail is used to support reading from the back; + // the mechanism is a little awkward and unused, but should work if needed + }; + } + + class ExternalStorage + { + public: + virtual int put(MemoryBuffer& bb) =0; + virtual int put(const void* x, detail::Save 
save) =0; + virtual void get(int i, MemoryBuffer& bb, size_t extra = 0) =0; + virtual void get(int i, void* x, detail::Load load) =0; + virtual void destroy(int i) =0; + }; + + class FileStorage: public ExternalStorage + { + private: + struct FileRecord + { + size_t size; + std::string name; + }; + + public: + FileStorage(const std::string& filename_template = "/tmp/DIY.XXXXXX"): + filename_templates_(1, filename_template), + count_(0), current_size_(0), max_size_(0) {} + + FileStorage(const std::vector& filename_templates): + filename_templates_(filename_templates), + count_(0), current_size_(0), max_size_(0) {} + + virtual int put(MemoryBuffer& bb) override + { + auto log = get_logger(); + std::string filename; + int fh = open_random(filename); + + log->debug("FileStorage::put(): {}; buffer size: {}", filename, bb.size()); + + size_t sz = bb.buffer.size(); + size_t written = write(fh, &bb.buffer[0], sz); + if (written < sz || written == (size_t)-1) + log->warn("Could not write the full buffer to {}: written = {}; size = {}", filename, written, sz); + fsync(fh); + close(fh); + bb.wipe(); + +#if 0 // double-check the written file size: only for extreme debugging + FILE* fp = fopen(filename.c_str(), "r"); + fseek(fp, 0L, SEEK_END); + int fsz = ftell(fp); + if (fsz != sz) + log->warn("file size doesn't match the buffer size, {} vs {}", fsz, sz); + fclose(fp); +#endif + + return make_file_record(filename, sz); + } + + virtual int put(const void* x, detail::Save save) override + { + std::string filename; + int fh = open_random(filename); + + detail::FileBuffer fb(fdopen(fh, "w")); + save(x, fb); + size_t sz = fb.size(); + fclose(fb.file); + fsync(fh); + + return make_file_record(filename, sz); + } + + virtual void get(int i, MemoryBuffer& bb, size_t extra) override + { + FileRecord fr = extract_file_record(i); + + get_logger()->debug("FileStorage::get(): {}", fr.name); + + bb.buffer.reserve(fr.size + extra); + bb.buffer.resize(fr.size); + int fh = open(fr.name.c_str(), 
O_RDONLY | O_SYNC, 0600); + read(fh, &bb.buffer[0], fr.size); + close(fh); + + remove_file(fr); + } + + virtual void get(int i, void* x, detail::Load load) override + { + FileRecord fr = extract_file_record(i); + + //int fh = open(fr.name.c_str(), O_RDONLY | O_SYNC, 0600); + int fh = open(fr.name.c_str(), O_RDONLY, 0600); + detail::FileBuffer fb(fdopen(fh, "r")); + load(x, fb); + fclose(fb.file); + + remove_file(fr); + } + + virtual void destroy(int i) override + { + FileRecord fr; + { + CriticalMapAccessor accessor = filenames_.access(); + fr = (*accessor)[i]; + accessor->erase(i); + } + remove(fr.name.c_str()); + (*current_size_.access()) -= fr.size; + } + + int count() const { return (*count_.const_access()); } + size_t current_size() const { return (*current_size_.const_access()); } + size_t max_size() const { return (*max_size_.const_access()); } + + ~FileStorage() + { + for (FileRecordMap::const_iterator it = filenames_.const_access()->begin(); + it != filenames_.const_access()->end(); + ++it) + { + remove(it->second.name.c_str()); + } + } + + private: + int open_random(std::string& filename) const + { + if (filename_templates_.size() == 1) + filename = filename_templates_[0].c_str(); + else + { + // pick a template at random (very basic load balancing mechanism) + filename = filename_templates_[std::rand() % filename_templates_.size()].c_str(); + } +#ifdef __MACH__ + // TODO: figure out how to open with O_SYNC + int fh = mkstemp(const_cast(filename.c_str())); +#else + int fh = mkostemp(const_cast(filename.c_str()), O_WRONLY | O_SYNC); +#endif + + return fh; + } + + int make_file_record(const std::string& filename, size_t sz) + { + int res = (*count_.access())++; + FileRecord fr = { sz, filename }; + (*filenames_.access())[res] = fr; + + // keep track of sizes + critical_resource::accessor cur = current_size_.access(); + *cur += sz; + critical_resource::accessor max = max_size_.access(); + if (*cur > *max) + *max = *cur; + + return res; + } + + FileRecord 
extract_file_record(int i) + { + CriticalMapAccessor accessor = filenames_.access(); + FileRecord fr = (*accessor)[i]; + accessor->erase(i); + return fr; + } + + void remove_file(const FileRecord& fr) + { + remove(fr.name.c_str()); + (*current_size_.access()) -= fr.size; + } + + private: + typedef std::map FileRecordMap; + typedef critical_resource CriticalMap; + typedef CriticalMap::accessor CriticalMapAccessor; + + private: + std::vector filename_templates_; + CriticalMap filenames_; + critical_resource count_; + critical_resource current_size_, max_size_; + }; +} + +#endif diff --git a/diy/include/diy/thread.hpp b/diy/include/diy/thread.hpp new file mode 100644 index 000000000..1c9149a42 --- /dev/null +++ b/diy/include/diy/thread.hpp @@ -0,0 +1,31 @@ +#ifndef DIY_THREAD_H +#define DIY_THREAD_H + +#ifdef DIY_NO_THREADS +#include "no-thread.hpp" +#else + +#include "thread/fast_mutex.h" + +#include +#include + +namespace diy +{ + using std::thread; + using std::mutex; + using std::recursive_mutex; + namespace this_thread = std::this_thread; + + // TODO: replace with our own implementation using std::atomic_flag + using fast_mutex = tthread::fast_mutex; + + template + using lock_guard = std::unique_lock; +} + +#endif + +#include "critical-resource.hpp" + +#endif diff --git a/diy/include/diy/thread/fast_mutex.h b/diy/include/diy/thread/fast_mutex.h new file mode 100644 index 000000000..4d4b7cc43 --- /dev/null +++ b/diy/include/diy/thread/fast_mutex.h @@ -0,0 +1,248 @@ +/* -*- mode: c++; tab-width: 2; indent-tabs-mode: nil; -*- +Copyright (c) 2010-2012 Marcus Geelnard + +This software is provided 'as-is', without any express or implied +warranty. In no event will the authors be held liable for any damages +arising from the use of this software. + +Permission is granted to anyone to use this software for any purpose, +including commercial applications, and to alter it and redistribute it +freely, subject to the following restrictions: + + 1. 
The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + + 3. This notice may not be removed or altered from any source + distribution. +*/ + +#ifndef _FAST_MUTEX_H_ +#define _FAST_MUTEX_H_ + +/// @file + +// Which platform are we on? +#if !defined(_TTHREAD_PLATFORM_DEFINED_) + #if defined(_WIN32) || defined(__WIN32__) || defined(__WINDOWS__) + #define _TTHREAD_WIN32_ + #else + #define _TTHREAD_POSIX_ + #endif + #define _TTHREAD_PLATFORM_DEFINED_ +#endif + +// Check if we can support the assembly language level implementation (otherwise +// revert to the system API) +#if (defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))) || \ + (defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64))) || \ + (defined(__GNUC__) && (defined(__ppc__))) + #define _FAST_MUTEX_ASM_ +#else + #define _FAST_MUTEX_SYS_ +#endif + +#if defined(_TTHREAD_WIN32_) + #ifndef WIN32_LEAN_AND_MEAN + #define WIN32_LEAN_AND_MEAN + #define __UNDEF_LEAN_AND_MEAN + #endif + #include + #ifdef __UNDEF_LEAN_AND_MEAN + #undef WIN32_LEAN_AND_MEAN + #undef __UNDEF_LEAN_AND_MEAN + #endif +#else + #ifdef _FAST_MUTEX_ASM_ + #include + #else + #include + #endif +#endif + +namespace tthread { + +/// Fast mutex class. +/// This is a mutual exclusion object for synchronizing access to shared +/// memory areas for several threads. It is similar to the tthread::mutex class, +/// but instead of using system level functions, it is implemented as an atomic +/// spin lock with very low CPU overhead. +/// +/// The \c fast_mutex class is NOT compatible with the \c condition_variable +/// class (however, it IS compatible with the \c lock_guard class). 
It should +/// also be noted that the \c fast_mutex class typically does not provide +/// as accurate thread scheduling as a the standard \c mutex class does. +/// +/// Because of the limitations of the class, it should only be used in +/// situations where the mutex needs to be locked/unlocked very frequently. +/// +/// @note The "fast" version of this class relies on inline assembler language, +/// which is currently only supported for 32/64-bit Intel x86/AMD64 and +/// PowerPC architectures on a limited number of compilers (GNU g++ and MS +/// Visual C++). +/// For other architectures/compilers, system functions are used instead. +class fast_mutex { + public: + /// Constructor. +#if defined(_FAST_MUTEX_ASM_) + fast_mutex() : mLock(0) {} +#else + fast_mutex() + { + #if defined(_TTHREAD_WIN32_) + InitializeCriticalSection(&mHandle); + #elif defined(_TTHREAD_POSIX_) + pthread_mutex_init(&mHandle, NULL); + #endif + } +#endif + +#if !defined(_FAST_MUTEX_ASM_) + /// Destructor. + ~fast_mutex() + { + #if defined(_TTHREAD_WIN32_) + DeleteCriticalSection(&mHandle); + #elif defined(_TTHREAD_POSIX_) + pthread_mutex_destroy(&mHandle); + #endif + } +#endif + + /// Lock the mutex. + /// The method will block the calling thread until a lock on the mutex can + /// be obtained. The mutex remains locked until \c unlock() is called. + /// @see lock_guard + inline void lock() + { +#if defined(_FAST_MUTEX_ASM_) + bool gotLock; + do { + gotLock = try_lock(); + if(!gotLock) + { + #if defined(_TTHREAD_WIN32_) + Sleep(0); + #elif defined(_TTHREAD_POSIX_) + sched_yield(); + #endif + } + } while(!gotLock); +#else + #if defined(_TTHREAD_WIN32_) + EnterCriticalSection(&mHandle); + #elif defined(_TTHREAD_POSIX_) + pthread_mutex_lock(&mHandle); + #endif +#endif + } + + /// Try to lock the mutex. + /// The method will try to lock the mutex. If it fails, the function will + /// return immediately (non-blocking). 
+ /// @return \c true if the lock was acquired, or \c false if the lock could + /// not be acquired. + inline bool try_lock() + { +#if defined(_FAST_MUTEX_ASM_) + int oldLock; + #if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__)) + asm volatile ( + "movl $1,%%eax\n\t" + "xchg %%eax,%0\n\t" + "movl %%eax,%1\n\t" + : "=m" (mLock), "=m" (oldLock) + : + : "%eax", "memory" + ); + #elif defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64)) + int *ptrLock = &mLock; + __asm { + mov eax,1 + mov ecx,ptrLock + xchg eax,[ecx] + mov oldLock,eax + } + #elif defined(__GNUC__) && (defined(__ppc__)) + int newLock = 1; + asm volatile ( + "\n1:\n\t" + "lwarx %0,0,%1\n\t" + "cmpwi 0,%0,0\n\t" + "bne- 2f\n\t" + "stwcx. %2,0,%1\n\t" + "bne- 1b\n\t" + "isync\n" + "2:\n\t" + : "=&r" (oldLock) + : "r" (&mLock), "r" (newLock) + : "cr0", "memory" + ); + #endif + return (oldLock == 0); +#else + #if defined(_TTHREAD_WIN32_) + return TryEnterCriticalSection(&mHandle) ? true : false; + #elif defined(_TTHREAD_POSIX_) + return (pthread_mutex_trylock(&mHandle) == 0) ? true : false; + #endif +#endif + } + + /// Unlock the mutex. + /// If any threads are waiting for the lock on this mutex, one of them will + /// be unblocked. + inline void unlock() + { +#if defined(_FAST_MUTEX_ASM_) + #if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__)) + asm volatile ( + "movl $0,%%eax\n\t" + "xchg %%eax,%0\n\t" + : "=m" (mLock) + : + : "%eax", "memory" + ); + #elif defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64)) + int *ptrLock = &mLock; + __asm { + mov eax,0 + mov ecx,ptrLock + xchg eax,[ecx] + } + #elif defined(__GNUC__) && (defined(__ppc__)) + asm volatile ( + "sync\n\t" // Replace with lwsync where possible? 
+ : : : "memory" + ); + mLock = 0; + #endif +#else + #if defined(_TTHREAD_WIN32_) + LeaveCriticalSection(&mHandle); + #elif defined(_TTHREAD_POSIX_) + pthread_mutex_unlock(&mHandle); + #endif +#endif + } + + private: +#if defined(_FAST_MUTEX_ASM_) + int mLock; +#else + #if defined(_TTHREAD_WIN32_) + CRITICAL_SECTION mHandle; + #elif defined(_TTHREAD_POSIX_) + pthread_mutex_t mHandle; + #endif +#endif +}; + +} + +#endif // _FAST_MUTEX_H_ + diff --git a/diy/include/diy/time.hpp b/diy/include/diy/time.hpp new file mode 100644 index 000000000..d6b44c2e1 --- /dev/null +++ b/diy/include/diy/time.hpp @@ -0,0 +1,33 @@ +#ifndef DIY_TIME_HPP +#define DIY_TIME_HPP + +#include + +#ifdef __MACH__ +#include +#include +#endif + +namespace diy +{ + +typedef unsigned long time_type; + +inline time_type get_time() +{ +#ifdef __MACH__ // OS X does not have clock_gettime, use clock_get_time + clock_serv_t cclock; + mach_timespec_t ts; + host_get_clock_service(mach_host_self(), CALENDAR_CLOCK, &cclock); + clock_get_time(cclock, &ts); + mach_port_deallocate(mach_task_self(), cclock); +#else + timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); +#endif + return ts.tv_sec*1000 + ts.tv_nsec/1000000; +} + +} + +#endif diff --git a/diy/include/diy/types.hpp b/diy/include/diy/types.hpp new file mode 100644 index 000000000..d52e75030 --- /dev/null +++ b/diy/include/diy/types.hpp @@ -0,0 +1,85 @@ +#ifndef DIY_TYPES_HPP +#define DIY_TYPES_HPP + +#include +#include "constants.h" +#include "point.hpp" + +namespace diy +{ + struct BlockID + { + int gid, proc; + }; + + template + struct Bounds + { + using Coordinate = Coordinate_; + + Point min, max; + }; + using DiscreteBounds = Bounds; + using ContinuousBounds = Bounds; + + //! 
Helper to create a 1-dimensional discrete domain with the specified extents + inline + diy::DiscreteBounds + interval(int from, int to) { DiscreteBounds domain; domain.min[0] = from; domain.max[0] = to; return domain; } + + struct Direction: public Point + { + Direction() { for (int i = 0; i < DIY_MAX_DIM; ++i) (*this)[i] = 0; } + Direction(int dir) + { + for (int i = 0; i < DIY_MAX_DIM; ++i) (*this)[i] = 0; + if (dir & DIY_X0) (*this)[0] -= 1; + if (dir & DIY_X1) (*this)[0] += 1; + if (dir & DIY_Y0) (*this)[1] -= 1; + if (dir & DIY_Y1) (*this)[1] += 1; + if (dir & DIY_Z0) (*this)[2] -= 1; + if (dir & DIY_Z1) (*this)[2] += 1; + if (dir & DIY_T0) (*this)[3] -= 1; + if (dir & DIY_T1) (*this)[3] += 1; + } + + bool + operator==(const diy::Direction& y) const + { + for (int i = 0; i < DIY_MAX_DIM; ++i) + if ((*this)[i] != y[i]) return false; + return true; + } + + // lexicographic comparison + bool + operator<(const diy::Direction& y) const + { + for (int i = 0; i < DIY_MAX_DIM; ++i) + { + if ((*this)[i] < y[i]) return true; + if ((*this)[i] > y[i]) return false; + } + return false; + } + }; + + // Selector of bounds value type + template + struct BoundsValue + { + using type = typename Bounds_::Coordinate; + }; + + inline + bool + operator<(const diy::BlockID& x, const diy::BlockID& y) + { return x.gid < y.gid; } + + inline + bool + operator==(const diy::BlockID& x, const diy::BlockID& y) + { return x.gid == y.gid; } +} + +#endif diff --git a/diy/include/diy/vertices.hpp b/diy/include/diy/vertices.hpp new file mode 100644 index 000000000..423209fd6 --- /dev/null +++ b/diy/include/diy/vertices.hpp @@ -0,0 +1,54 @@ +#ifndef DIY_VERTICES_HPP +#define DIY_VERTICES_HPP + +#include + +namespace diy +{ + +namespace detail +{ + template + struct IsLast + { + static constexpr bool value = (Vertex::dimension() - 1 == I); + }; + + template + struct ForEach + { + void operator()(Vertex& pos, const Vertex& from, const Vertex& to, const Callback& callback) const + { + for (pos[I] = 
from[I]; pos[I] <= to[I]; ++pos[I]) + ForEach::value>()(pos, from, to, callback); + } + }; + + template + struct ForEach + { + void operator()(Vertex& pos, const Vertex& from, const Vertex& to, const Callback& callback) const + { + for (pos[I] = from[I]; pos[I] <= to[I]; ++pos[I]) + callback(pos); + } + }; +} + +template +void for_each(const Vertex& from, const Vertex& to, const Callback& callback) +{ + Vertex pos; + grid::detail::ForEach::value>()(pos, from, to, callback); +} + +template +void for_each(const Vertex& shape, const Callback& callback) +{ + // specify grid namespace to disambiguate with std::for_each(...) + grid::for_each(Vertex::zero(), shape - Vertex::one(), callback); +} + +} + +#endif diff --git a/examples/game_of_life/GameOfLife.cxx b/examples/game_of_life/GameOfLife.cxx index 304ceb6d5..bdd14e526 100644 --- a/examples/game_of_life/GameOfLife.cxx +++ b/examples/game_of_life/GameOfLife.cxx @@ -357,7 +357,7 @@ int main(int argc, char** argv) vtkm::cont::DataSetBuilderUniform builder; vtkm::cont::DataSet data = builder.Create(vtkm::Id2(x, y)); - vtkm::cont::Field stateField("state", vtkm::cont::Field::ASSOC_POINTS, input_state); + auto stateField = vtkm::cont::make_Field("state", vtkm::cont::Field::ASSOC_POINTS, input_state); data.AddField(stateField); GameOfLife filter; diff --git a/vtkm/CMakeLists.txt b/vtkm/CMakeLists.txt index 443db458f..f48d25794 100644 --- a/vtkm/CMakeLists.txt +++ b/vtkm/CMakeLists.txt @@ -34,6 +34,7 @@ set(headers Bounds.h CellShape.h CellTraits.h + Flags.h Hash.h ImplicitFunction.h ListTag.h diff --git a/vtkm/Flags.h b/vtkm/Flags.h new file mode 100644 index 000000000..ba90d3f66 --- /dev/null +++ b/vtkm/Flags.h @@ -0,0 +1,33 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. 
+// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2014 UT-Battelle, LLC. +// Copyright 2014 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. +//============================================================================ +#ifndef vtk_m_Flags_h +#define vtk_m_Flags_h + +namespace vtkm +{ + +enum class CopyFlag +{ + Off = 0, + On = 1 +}; +} + +#endif // vtk_m_Flags_h diff --git a/vtkm/ListTag.h b/vtkm/ListTag.h index afcd6a5ac..691e27c2a 100644 --- a/vtkm/ListTag.h +++ b/vtkm/ListTag.h @@ -94,7 +94,7 @@ VTKM_CONT void ListForEach(Functor&& f, ListTag, Args&&... args) } /// Generate a tag that is the cross product of two other tags. The resulting -// a tag has the form of Tag< std::pair, std::pair .... > +// a tag has the form of Tag< brigand::list, brigand::list .... > /// template struct ListCrossProduct : detail::ListRoot diff --git a/vtkm/benchmarking/BenchmarkRayTracing.cxx b/vtkm/benchmarking/BenchmarkRayTracing.cxx new file mode 100644 index 000000000..c32eb137f --- /dev/null +++ b/vtkm/benchmarking/BenchmarkRayTracing.cxx @@ -0,0 +1,113 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. 
+// +// Copyright 2017 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2017 UT-Battelle, LLC. +// Copyright 2017 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. +//============================================================================ + +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include + +using namespace vtkm::benchmarking; +namespace vtkm +{ +namespace benchmarking +{ + +template +struct BenchRayTracing +{ + vtkm::rendering::raytracing::RayTracer Tracer; + vtkm::rendering::raytracing::Camera RayCamera; + vtkm::cont::ArrayHandle> Indices; + vtkm::rendering::raytracing::Ray Rays; + vtkm::Id NumberOfTriangles; + vtkm::cont::CoordinateSystem Coords; + vtkm::cont::DataSet Data; + + VTKM_CONT BenchRayTracing() + { + vtkm::cont::testing::MakeTestDataSet maker; + Data = maker.Make3DUniformDataSet2(); + Coords = Data.GetCoordinateSystem(); + + vtkm::rendering::Camera camera; + vtkm::Bounds bounds = Data.GetCoordinateSystem().GetBounds(); + camera.ResetToBounds(bounds); + + vtkm::cont::DynamicCellSet cellset = Data.GetCellSet(); + vtkm::rendering::internal::RunTriangulator(cellset, Indices, NumberOfTriangles); + + vtkm::rendering::CanvasRayTracer canvas(1920, 1080); + RayCamera.SetParameters(camera, canvas); + RayCamera.CreateRays(Rays, Coords); + + Rays.Buffers.at(0).InitConst(0.f); + + vtkm::cont::Field field = Data.GetField("pointvar"); + vtkm::Range range = field.GetRange().GetPortalConstControl().Get(0); + + Tracer.SetData(Coords.GetData(), Indices, field, NumberOfTriangles, range, bounds); + + vtkm::cont::ArrayHandle> colors; + 
vtkm::rendering::ColorTable("cool2warm").Sample(100, colors); + + Tracer.SetColorMap(colors); + Tracer.Render(Rays); + } + + VTKM_CONT + vtkm::Float64 operator()() + { + vtkm::cont::Timer timer; + + RayCamera.CreateRays(Rays, Coords); + Tracer.Render(Rays); + + return timer.GetElapsedTime(); + } + + VTKM_CONT + std::string Description() const { return "A ray tracing benchmark"; } +}; + +VTKM_MAKE_BENCHMARK(RayTracing, BenchRayTracing); +} +} // end namespace vtkm::benchmarking + +int main(int, char* []) +{ + using TestTypes = vtkm::ListTagBase; + VTKM_RUN_BENCHMARK(RayTracing, vtkm::ListTagBase()); + return 0; +} diff --git a/vtkm/benchmarking/CMakeLists.txt b/vtkm/benchmarking/CMakeLists.txt index 67b7824d5..c8661f844 100644 --- a/vtkm/benchmarking/CMakeLists.txt +++ b/vtkm/benchmarking/CMakeLists.txt @@ -25,16 +25,10 @@ set(benchmarks BenchmarkFieldAlgorithms BenchmarkTopologyAlgorithms ) -#set(benchmark_files -# BenchmarkArrayTransfer.cxx -# BenchmarkCopySpeeds.cxx -# BenchmarkDeviceAdapter.cxx -# BenchmarkFieldAlgorithms.cxx -# BenchmarkTopologyAlgorithms.cxx -# ) -#set(benchmark_headers -# Benchmarker.h -# ) + +if(TARGET vtkm_rendering) + list(APPEND benchmarks BenchmarkRayTracing) +endif() function(add_benchmark name files) add_executable(${name}_SERIAL ${files}) diff --git a/vtkm/cont/Algorithm.h b/vtkm/cont/Algorithm.h new file mode 100644 index 000000000..833f19ed0 --- /dev/null +++ b/vtkm/cont/Algorithm.h @@ -0,0 +1,537 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2014 UT-Battelle, LLC. 
+// Copyright 2014 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. +//============================================================================ +#ifndef vtk_m_cont_Algorithm_h +#define vtk_m_cont_Algorithm_h + +#include + +#include +#include +#include + +namespace vtkm +{ +namespace cont +{ + +namespace +{ +struct CopyFunctor +{ + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::Copy(std::forward(args)...); + return true; + } +}; + +struct CopyIfFunctor +{ + + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::CopyIf(std::forward(args)...); + return true; + } +}; + +struct CopySubRangeFunctor +{ + bool valid; + + template + VTKM_CONT bool operator()(Device, Args&&... args) + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + valid = vtkm::cont::DeviceAdapterAlgorithm::CopySubRange(std::forward(args)...); + return true; + } +}; + +struct LowerBoundsFunctor +{ + + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::LowerBounds(std::forward(args)...); + return true; + } +}; + +template +struct ReduceFunctor +{ + U result; + + ReduceFunctor() + : result(U(0)) + { + } + + template + VTKM_CONT bool operator()(Device, Args&&... args) + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + result = vtkm::cont::DeviceAdapterAlgorithm::Reduce(std::forward(args)...); + return true; + } +}; + +struct ReduceByKeyFunctor +{ + template + VTKM_CONT bool operator()(Device, Args&&... 
args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::ReduceByKey(std::forward(args)...); + return true; + } +}; + +template +struct ScanInclusiveResultFunctor +{ + T result; + ScanInclusiveResultFunctor() + : result(T(0)) + { + } + + template + VTKM_CONT bool operator()(Device, Args&&... args) + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + result = vtkm::cont::DeviceAdapterAlgorithm::ScanInclusive(std::forward(args)...); + return true; + } +}; + +template +struct StreamingScanExclusiveFunctor +{ + T result; + StreamingScanExclusiveFunctor() + : result(T(0)) + { + } + + template + VTKM_CONT bool operator()(Device, + const vtkm::Id numBlocks, + const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& output) + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + result = + vtkm::cont::DeviceAdapterAlgorithm::StreamingScanExclusive(numBlocks, input, output); + return true; + } +}; + +struct ScanInclusiveByKeyFunctor +{ + ScanInclusiveByKeyFunctor() {} + + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::ScanInclusiveByKey(std::forward(args)...); + return true; + } +}; + +template +struct ScanExclusiveFunctor +{ + T result; + ScanExclusiveFunctor() + : result(T(0)) + { + } + + template + VTKM_CONT bool operator()(Device, Args&&... args) + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + result = vtkm::cont::DeviceAdapterAlgorithm::ScanExclusive(std::forward(args)...); + return true; + } +}; + +struct ScanExclusiveByKeyFunctor +{ + ScanExclusiveByKeyFunctor() {} + + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::ScanExclusiveByKey(std::forward(args)...); + return true; + } +}; + +struct ScheduleFunctor +{ + template + VTKM_CONT bool operator()(Device, Args&&... 
args) + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::Schedule(std::forward(args)...); + return true; + } +}; + +struct SortFunctor +{ + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::Sort(std::forward(args)...); + return true; + } +}; + +struct SortByKeyFunctor +{ + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::SortByKey(std::forward(args)...); + return true; + } +}; + +struct SynchronizeFunctor +{ + template + VTKM_CONT bool operator()(Device) + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::Synchronize(); + return true; + } +}; + +struct UniqueFunctor +{ + template + VTKM_CONT bool operator()(Device, Args&&... args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::Unique(std::forward(args)...); + return true; + } +}; + +struct UpperBoundsFunctor +{ + template + VTKM_CONT bool operator()(Device, Args&&... 
args) const + { + VTKM_IS_DEVICE_ADAPTER_TAG(Device); + vtkm::cont::DeviceAdapterAlgorithm::UpperBounds(std::forward(args)...); + return true; + } +}; +} // annonymous namespace + +struct Algorithm +{ + + template + VTKM_CONT static void Copy(const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& output) + { + vtkm::cont::TryExecute(CopyFunctor(), input, output); + } + + template + VTKM_CONT static void CopyIf(const vtkm::cont::ArrayHandle& input, + const vtkm::cont::ArrayHandle& stencil, + vtkm::cont::ArrayHandle& output) + { + vtkm::cont::TryExecute(CopyIfFunctor(), input, stencil, output); + } + + template + VTKM_CONT static void CopyIf(const vtkm::cont::ArrayHandle& input, + const vtkm::cont::ArrayHandle& stencil, + vtkm::cont::ArrayHandle& output, + UnaryPredicate unary_predicate) + { + vtkm::cont::TryExecute(CopyIfFunctor(), input, stencil, output, unary_predicate); + } + + template + VTKM_CONT static bool CopySubRange(const vtkm::cont::ArrayHandle& input, + vtkm::Id inputStartIndex, + vtkm::Id numberOfElementsToCopy, + vtkm::cont::ArrayHandle& output, + vtkm::Id outputIndex = 0) + { + CopySubRangeFunctor functor; + vtkm::cont::TryExecute( + functor, input, inputStartIndex, numberOfElementsToCopy, output, outputIndex); + return functor.valid; + } + + template + VTKM_CONT static void LowerBounds(const vtkm::cont::ArrayHandle& input, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& output) + { + vtkm::cont::TryExecute(LowerBoundsFunctor(), input, values, output); + } + + template + VTKM_CONT static void LowerBounds(const vtkm::cont::ArrayHandle& input, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& output, + BinaryCompare binary_compare) + { + vtkm::cont::TryExecute(LowerBoundsFunctor(), input, values, output, binary_compare); + } + + template + VTKM_CONT static void LowerBounds(const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& values_output) + { + vtkm::cont::TryExecute(LowerBoundsFunctor(), 
input, values_output); + } + + template + VTKM_CONT static U Reduce(const vtkm::cont::ArrayHandle& input, U initialValue) + { + ReduceFunctor functor; + vtkm::cont::TryExecute(functor, input, initialValue); + return functor.result; + } + + template + VTKM_CONT static U Reduce(const vtkm::cont::ArrayHandle& input, + U initialValue, + BinaryFunctor binary_functor) + { + ReduceFunctor functor; + vtkm::cont::TryExecute(functor, input, initialValue, binary_functor); + return functor.result; + } + + template + VTKM_CONT static void ReduceByKey(const vtkm::cont::ArrayHandle& keys, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& keys_output, + vtkm::cont::ArrayHandle& values_output, + BinaryFunctor binary_functor) + { + vtkm::cont::TryExecute( + ReduceByKeyFunctor(), keys, values, keys_output, values_output, binary_functor); + } + + template + VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& output) + { + ScanInclusiveResultFunctor functor; + vtkm::cont::TryExecute(functor, input, output); + return functor.result; + } + + template + VTKM_CONT static T StreamingScanExclusive(const vtkm::Id numBlocks, + const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& output) + { + StreamingScanExclusiveFunctor functor; + vtkm::cont::TryExecute(functor, numBlocks, input, output); + return functor.result; + } + + template + VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& output, + BinaryFunctor binary_functor) + { + ScanInclusiveResultFunctor functor; + vtkm::cont::TryExecute(functor, input, output, binary_functor); + return functor.result; + } + + template + VTKM_CONT static void ScanInclusiveByKey(const vtkm::cont::ArrayHandle& keys, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& values_output, + BinaryFunctor binary_functor) + { + vtkm::cont::TryExecute( + ScanInclusiveByKeyFunctor(), keys, values, values_output, binary_functor); + } + 
+ template + VTKM_CONT static void ScanInclusiveByKey(const vtkm::cont::ArrayHandle& keys, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& values_output) + { + vtkm::cont::TryExecute(ScanInclusiveByKeyFunctor(), keys, values, values_output); + } + + template + VTKM_CONT static T ScanExclusive(const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& output) + { + ScanExclusiveFunctor functor; + vtkm::cont::TryExecute(functor, input, output); + return functor.result; + } + + template + VTKM_CONT static T ScanExclusive(const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& output, + BinaryFunctor binaryFunctor, + const T& initialValue) + { + ScanExclusiveFunctor functor; + vtkm::cont::TryExecute(functor, input, output, binaryFunctor, initialValue); + return functor.result; + } + + template + VTKM_CONT static void ScanExclusiveByKey(const vtkm::cont::ArrayHandle& keys, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& output, + const U& initialValue, + BinaryFunctor binaryFunctor) + { + vtkm::cont::TryExecute( + ScanExclusiveByKeyFunctor(), keys, values, output, initialValue, binaryFunctor); + } + + template + VTKM_CONT static void ScanExclusiveByKey(const vtkm::cont::ArrayHandle& keys, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& output) + { + vtkm::cont::TryExecute(ScanExclusiveByKeyFunctor(), keys, values, output); + } + + template + VTKM_CONT static void Schedule(Functor functor, vtkm::Id numInstances) + { + vtkm::cont::TryExecute(ScheduleFunctor(), functor, numInstances); + } + + template + VTKM_CONT static void Schedule(Functor functor, vtkm::Id3 rangeMax) + { + vtkm::cont::TryExecute(ScheduleFunctor(), functor, rangeMax); + } + + template + VTKM_CONT static void Sort(vtkm::cont::ArrayHandle& values) + { + vtkm::cont::TryExecute(SortFunctor(), values); + } + + template + VTKM_CONT static void Sort(vtkm::cont::ArrayHandle& values, + BinaryCompare binary_compare) + { + 
vtkm::cont::TryExecute(SortFunctor(), values, binary_compare); + } + + template + VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle& keys, + vtkm::cont::ArrayHandle& values) + { + vtkm::cont::TryExecute(SortByKeyFunctor(), keys, values); + } + + template + VTKM_CONT static void SortByKey(vtkm::cont::ArrayHandle& keys, + vtkm::cont::ArrayHandle& values, + BinaryCompare binary_compare) + { + vtkm::cont::TryExecute(SortByKeyFunctor(), keys, values, binary_compare); + } + + VTKM_CONT static void Synchronize() { vtkm::cont::TryExecute(SynchronizeFunctor()); } + + template + VTKM_CONT static void Unique(vtkm::cont::ArrayHandle& values) + { + vtkm::cont::TryExecute(UniqueFunctor(), values); + } + + template + VTKM_CONT static void Unique(vtkm::cont::ArrayHandle& values, + BinaryCompare binary_compare) + { + vtkm::cont::TryExecute(UniqueFunctor(), values, binary_compare); + } + + template + VTKM_CONT static void UpperBounds(const vtkm::cont::ArrayHandle& input, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& output) + { + vtkm::cont::TryExecute(UpperBoundsFunctor(), input, values, output); + } + + template + VTKM_CONT static void UpperBounds(const vtkm::cont::ArrayHandle& input, + const vtkm::cont::ArrayHandle& values, + vtkm::cont::ArrayHandle& output, + BinaryCompare binary_compare) + { + vtkm::cont::TryExecute(UpperBoundsFunctor(), input, values, output, binary_compare); + } + + template + VTKM_CONT static void UpperBounds(const vtkm::cont::ArrayHandle& input, + vtkm::cont::ArrayHandle& values_output) + { + vtkm::cont::TryExecute(UpperBoundsFunctor(), input, values_output); + } +}; +} +} // namespace vtkm::cont + +#endif //vtk_m_cont_Algorithm_h diff --git a/vtkm/cont/ArrayHandle.h b/vtkm/cont/ArrayHandle.h index 874e12175..09d910314 100644 --- a/vtkm/cont/ArrayHandle.h +++ b/vtkm/cont/ArrayHandle.h @@ -23,6 +23,7 @@ #include #include +#include #include #include @@ -31,6 +32,7 @@ #include #include +#include #include #include #include @@ -503,23 
+505,35 @@ public: /// A convenience function for creating an ArrayHandle from a standard C array. /// template -VTKM_CONT vtkm::cont::ArrayHandle make_ArrayHandle(const T* array, - vtkm::Id length) +VTKM_CONT vtkm::cont::ArrayHandle +make_ArrayHandle(const T* array, vtkm::Id length, vtkm::CopyFlag copy = vtkm::CopyFlag::Off) { using ArrayHandleType = vtkm::cont::ArrayHandle; - using StorageType = vtkm::cont::internal::Storage; - return ArrayHandleType(StorageType(array, length)); + if (copy == vtkm::CopyFlag::On) + { + ArrayHandleType handle; + handle.Allocate(length); + std::copy( + array, array + length, vtkm::cont::ArrayPortalToIteratorBegin(handle.GetPortalControl())); + return handle; + } + else + { + using StorageType = vtkm::cont::internal::Storage; + return ArrayHandleType(StorageType(array, length)); + } } /// A convenience function for creating an ArrayHandle from an std::vector. /// template VTKM_CONT vtkm::cont::ArrayHandle make_ArrayHandle( - const std::vector& array) + const std::vector& array, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) { if (!array.empty()) { - return make_ArrayHandle(&array.front(), static_cast(array.size())); + return make_ArrayHandle(&array.front(), static_cast(array.size()), copy); } else { diff --git a/vtkm/cont/CMakeLists.txt b/vtkm/cont/CMakeLists.txt index 76cb68e74..8d09016ac 100644 --- a/vtkm/cont/CMakeLists.txt +++ b/vtkm/cont/CMakeLists.txt @@ -19,6 +19,7 @@ ##============================================================================ set(headers + Algorithm.h ArrayCopy.h ArrayHandle.h ArrayHandleCast.h @@ -62,6 +63,7 @@ set(headers DeviceAdapterListTag.h DynamicArrayHandle.h DynamicCellSet.h + EnvironmentTracker.h Error.h ErrorBadAllocation.h ErrorBadType.h @@ -96,7 +98,9 @@ set(sources CellSetExplicit.cxx CellSetStructured.cxx CoordinateSystem.cxx + DataSet.cxx DynamicArrayHandle.cxx + EnvironmentTracker.cxx Field.cxx internal/SimplePolymorphicContainer.cxx MultiBlock.cxx @@ -140,5 +144,10 @@ vtkm_library( NAME 
vtkm_cont WRAP_FOR_CUDA ${device_sources} ) target_link_libraries(vtkm_cont PUBLIC vtkm_compiler_flags ${backends}) +if(VTKm_ENABLE_MPI) + # This will become a required dependency eventually. + target_link_libraries(vtkm_cont PUBLIC diy) +endif() + #----------------------------------------------------------------------------- add_subdirectory(testing) diff --git a/vtkm/cont/CoordinateSystem.h b/vtkm/cont/CoordinateSystem.h index 1504a83bc..6bfa49745 100644 --- a/vtkm/cont/CoordinateSystem.h +++ b/vtkm/cont/CoordinateSystem.h @@ -110,18 +110,6 @@ public: { } - template - VTKM_CONT CoordinateSystem(std::string name, const std::vector& data) - : Superclass(name, ASSOC_POINTS, data) - { - } - - template - VTKM_CONT CoordinateSystem(std::string name, const T* data, vtkm::Id numberOfValues) - : Superclass(name, ASSOC_POINTS, data, numberOfValues) - { - } - /// This constructor of coordinate system sets up a regular grid of points. /// VTKM_CONT @@ -225,9 +213,27 @@ public: }; template -void CastAndCall(const vtkm::cont::CoordinateSystem& coords, const Functor& f, Args&&... args) +void CastAndCall(const vtkm::cont::CoordinateSystem& coords, Functor&& f, Args&&... 
args) { - coords.GetData().CastAndCall(f, std::forward(args)...); + coords.GetData().CastAndCall(std::forward(f), std::forward(args)...); +} + +template +vtkm::cont::CoordinateSystem make_CoordinateSystem(std::string name, + const std::vector& data, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::CoordinateSystem(name, vtkm::cont::make_ArrayHandle(data, copy)); +} + +template +vtkm::cont::CoordinateSystem make_CoordinateSystem(std::string name, + const T* data, + vtkm::Id numberOfValues, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::CoordinateSystem(name, + vtkm::cont::make_ArrayHandle(data, numberOfValues, copy)); } namespace internal diff --git a/vtkm/cont/DataSet.cxx b/vtkm/cont/DataSet.cxx new file mode 100644 index 000000000..3cfc36982 --- /dev/null +++ b/vtkm/cont/DataSet.cxx @@ -0,0 +1,164 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2015 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2015 UT-Battelle, LLC. +// Copyright 2015 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. 
+//============================================================================ + +#include + +namespace vtkm +{ +namespace cont +{ + +DataSet::DataSet() +{ +} + +void DataSet::Clear() +{ + this->CoordSystems.clear(); + this->Fields.clear(); + this->CellSets.clear(); +} + +const vtkm::cont::Field& DataSet::GetField(vtkm::Id index) const +{ + VTKM_ASSERT((index >= 0) && (index < this->GetNumberOfFields())); + return this->Fields[static_cast(index)]; +} + +vtkm::Id DataSet::GetFieldIndex(const std::string& name, + vtkm::cont::Field::AssociationEnum assoc) const +{ + bool found; + vtkm::Id index = this->FindFieldIndex(name, assoc, found); + if (found) + { + return index; + } + else + { + throw vtkm::cont::ErrorBadValue("No field with requested name: " + name); + } +} + +const vtkm::cont::CoordinateSystem& DataSet::GetCoordinateSystem(vtkm::Id index) const +{ + VTKM_ASSERT((index >= 0) && (index < this->GetNumberOfCoordinateSystems())); + return this->CoordSystems[static_cast(index)]; +} + +vtkm::Id DataSet::GetCoordinateSystemIndex(const std::string& name) const +{ + bool found; + vtkm::Id index = this->FindCoordinateSystemIndex(name, found); + if (found) + { + return index; + } + else + { + throw vtkm::cont::ErrorBadValue("No coordinate system with requested name"); + } +} + +vtkm::Id DataSet::GetCellSetIndex(const std::string& name) const +{ + bool found; + vtkm::Id index = this->FindCellSetIndex(name, found); + if (found) + { + return index; + } + else + { + throw vtkm::cont::ErrorBadValue("No cell set with requested name"); + } +} + +void DataSet::PrintSummary(std::ostream& out) const +{ + out << "DataSet:\n"; + out << " CoordSystems[" << this->CoordSystems.size() << "]\n"; + for (std::size_t index = 0; index < this->CoordSystems.size(); index++) + { + this->CoordSystems[index].PrintSummary(out); + } + + out << " CellSets[" << this->GetNumberOfCellSets() << "]\n"; + for (vtkm::Id index = 0; index < this->GetNumberOfCellSets(); index++) + { + 
this->GetCellSet(index).PrintSummary(out); + } + + out << " Fields[" << this->GetNumberOfFields() << "]\n"; + for (vtkm::Id index = 0; index < this->GetNumberOfFields(); index++) + { + this->GetField(index).PrintSummary(out); + } +} + +vtkm::Id DataSet::FindFieldIndex(const std::string& name, + vtkm::cont::Field::AssociationEnum association, + bool& found) const +{ + for (std::size_t index = 0; index < this->Fields.size(); ++index) + { + if ((association == vtkm::cont::Field::ASSOC_ANY || + association == this->Fields[index].GetAssociation()) && + this->Fields[index].GetName() == name) + { + found = true; + return static_cast(index); + } + } + found = false; + return -1; +} + + +vtkm::Id DataSet::FindCoordinateSystemIndex(const std::string& name, bool& found) const +{ + for (std::size_t index = 0; index < this->CoordSystems.size(); ++index) + { + if (this->CoordSystems[index].GetName() == name) + { + found = true; + return static_cast(index); + } + } + found = false; + return -1; +} + +vtkm::Id DataSet::FindCellSetIndex(const std::string& name, bool& found) const +{ + for (std::size_t index = 0; index < static_cast(this->GetNumberOfCellSets()); ++index) + { + if (this->CellSets[index].GetName() == name) + { + found = true; + return static_cast(index); + } + } + found = false; + return -1; +} + +} // namespace cont +} // namespace vtkm diff --git a/vtkm/cont/DataSet.h b/vtkm/cont/DataSet.h index 71e3a6da6..f4854b07d 100644 --- a/vtkm/cont/DataSet.h +++ b/vtkm/cont/DataSet.h @@ -20,6 +20,8 @@ #ifndef vtk_m_cont_DataSet_h #define vtk_m_cont_DataSet_h +#include + #include #include #include @@ -33,29 +35,17 @@ namespace vtkm namespace cont { -class DataSet +class VTKM_CONT_EXPORT DataSet { public: - VTKM_CONT - DataSet() {} + VTKM_CONT DataSet(); + + VTKM_CONT void Clear(); + + VTKM_CONT void AddField(const Field& field) { this->Fields.push_back(field); } VTKM_CONT - void Clear() - { - this->CoordSystems.clear(); - this->Fields.clear(); - this->CellSets.clear(); - } - - 
VTKM_CONT - void AddField(Field field) { this->Fields.push_back(field); } - - VTKM_CONT - const vtkm::cont::Field& GetField(vtkm::Id index) const - { - VTKM_ASSERT((index >= 0) && (index < this->GetNumberOfFields())); - return this->Fields[static_cast(index)]; - } + const vtkm::cont::Field& GetField(vtkm::Id index) const; VTKM_CONT bool HasField(const std::string& name, @@ -69,19 +59,7 @@ public: VTKM_CONT vtkm::Id GetFieldIndex( const std::string& name, - vtkm::cont::Field::AssociationEnum assoc = vtkm::cont::Field::ASSOC_ANY) const - { - bool found; - vtkm::Id index = this->FindFieldIndex(name, assoc, found); - if (found) - { - return index; - } - else - { - throw vtkm::cont::ErrorBadValue("No field with requested name: " + name); - } - } + vtkm::cont::Field::AssociationEnum assoc = vtkm::cont::Field::ASSOC_ANY) const; VTKM_CONT const vtkm::cont::Field& GetField( @@ -104,14 +82,13 @@ public: } VTKM_CONT - void AddCoordinateSystem(vtkm::cont::CoordinateSystem cs) { this->CoordSystems.push_back(cs); } + void AddCoordinateSystem(const vtkm::cont::CoordinateSystem& cs) + { + this->CoordSystems.push_back(cs); + } VTKM_CONT - const vtkm::cont::CoordinateSystem& GetCoordinateSystem(vtkm::Id index = 0) const - { - VTKM_ASSERT((index >= 0) && (index < this->GetNumberOfCoordinateSystems())); - return this->CoordSystems[static_cast(index)]; - } + const vtkm::cont::CoordinateSystem& GetCoordinateSystem(vtkm::Id index = 0) const; VTKM_CONT bool HasCoordinateSystem(const std::string& name) const @@ -122,19 +99,7 @@ public: } VTKM_CONT - vtkm::Id GetCoordinateSystemIndex(const std::string& name) const - { - bool found; - vtkm::Id index = this->FindCoordinateSystemIndex(name, found); - if (found) - { - return index; - } - else - { - throw vtkm::cont::ErrorBadValue("No coordinate system with requested name"); - } - } + vtkm::Id GetCoordinateSystemIndex(const std::string& name) const; VTKM_CONT const vtkm::cont::CoordinateSystem& GetCoordinateSystem(const std::string& name) const 
@@ -143,7 +108,7 @@ public: } VTKM_CONT - void AddCellSet(vtkm::cont::DynamicCellSet cellSet) { this->CellSets.push_back(cellSet); } + void AddCellSet(const vtkm::cont::DynamicCellSet& cellSet) { this->CellSets.push_back(cellSet); } template VTKM_CONT void AddCellSet(const CellSetType& cellSet) @@ -168,19 +133,7 @@ public: } VTKM_CONT - vtkm::Id GetCellSetIndex(const std::string& name) const - { - bool found; - vtkm::Id index = this->FindCellSetIndex(name, found); - if (found) - { - return index; - } - else - { - throw vtkm::cont::ErrorBadValue("No cell set with requested name"); - } - } + vtkm::Id GetCellSetIndex(const std::string& name) const; VTKM_CONT vtkm::cont::DynamicCellSet GetCellSet(const std::string& name) const @@ -207,27 +160,7 @@ public: } VTKM_CONT - void PrintSummary(std::ostream& out) const - { - out << "DataSet:\n"; - out << " CoordSystems[" << this->CoordSystems.size() << "]\n"; - for (std::size_t index = 0; index < this->CoordSystems.size(); index++) - { - this->CoordSystems[index].PrintSummary(out); - } - - out << " CellSets[" << this->GetNumberOfCellSets() << "]\n"; - for (vtkm::Id index = 0; index < this->GetNumberOfCellSets(); index++) - { - this->GetCellSet(index).PrintSummary(out); - } - - out << " Fields[" << this->GetNumberOfFields() << "]\n"; - for (vtkm::Id index = 0; index < this->GetNumberOfFields(); index++) - { - this->GetField(index).PrintSummary(out); - } - } + void PrintSummary(std::ostream& out) const; private: std::vector CoordSystems; @@ -237,51 +170,13 @@ private: VTKM_CONT vtkm::Id FindFieldIndex(const std::string& name, vtkm::cont::Field::AssociationEnum association, - bool& found) const - { - for (std::size_t index = 0; index < this->Fields.size(); ++index) - { - if ((association == vtkm::cont::Field::ASSOC_ANY || - association == this->Fields[index].GetAssociation()) && - this->Fields[index].GetName() == name) - { - found = true; - return static_cast(index); - } - } - found = false; - return -1; - } + bool& found) const; 
VTKM_CONT - vtkm::Id FindCoordinateSystemIndex(const std::string& name, bool& found) const - { - for (std::size_t index = 0; index < this->CoordSystems.size(); ++index) - { - if (this->CoordSystems[index].GetName() == name) - { - found = true; - return static_cast(index); - } - } - found = false; - return -1; - } + vtkm::Id FindCoordinateSystemIndex(const std::string& name, bool& found) const; VTKM_CONT - vtkm::Id FindCellSetIndex(const std::string& name, bool& found) const - { - for (std::size_t index = 0; index < static_cast(this->GetNumberOfCellSets()); ++index) - { - if (this->CellSets[index].GetName() == name) - { - found = true; - return static_cast(index); - } - } - found = false; - return -1; - } + vtkm::Id FindCellSetIndex(const std::string& name, bool& found) const; }; } // namespace cont diff --git a/vtkm/cont/DataSetFieldAdd.h b/vtkm/cont/DataSetFieldAdd.h index 341cf09a6..779b8f06d 100644 --- a/vtkm/cont/DataSetFieldAdd.h +++ b/vtkm/cont/DataSetFieldAdd.h @@ -57,7 +57,8 @@ public: const std::string& fieldName, const std::vector& field) { - dataSet.AddField(Field(fieldName, vtkm::cont::Field::ASSOC_POINTS, field)); + dataSet.AddField( + make_Field(fieldName, vtkm::cont::Field::ASSOC_POINTS, field, vtkm::CopyFlag::On)); } template @@ -66,7 +67,8 @@ public: const T* field, const vtkm::Id& n) { - dataSet.AddField(Field(fieldName, vtkm::cont::Field::ASSOC_POINTS, field, n)); + dataSet.AddField( + make_Field(fieldName, vtkm::cont::Field::ASSOC_POINTS, field, n, vtkm::CopyFlag::On)); } //Cell centered field @@ -94,7 +96,8 @@ public: const std::vector& field, const std::string& cellSetName) { - dataSet.AddField(Field(fieldName, vtkm::cont::Field::ASSOC_CELL_SET, cellSetName, field)); + dataSet.AddField(make_Field( + fieldName, vtkm::cont::Field::ASSOC_CELL_SET, cellSetName, field, vtkm::CopyFlag::On)); } template @@ -104,7 +107,8 @@ public: const vtkm::Id& n, const std::string& cellSetName) { - dataSet.AddField(Field(fieldName, 
vtkm::cont::Field::ASSOC_CELL_SET, cellSetName, field, n)); + dataSet.AddField(make_Field( + fieldName, vtkm::cont::Field::ASSOC_CELL_SET, cellSetName, field, n, vtkm::CopyFlag::On)); } VTKM_CONT diff --git a/vtkm/cont/DeviceAdapterAlgorithm.h b/vtkm/cont/DeviceAdapterAlgorithm.h index 3494c1d5d..ccb042055 100644 --- a/vtkm/cont/DeviceAdapterAlgorithm.h +++ b/vtkm/cont/DeviceAdapterAlgorithm.h @@ -223,14 +223,14 @@ struct DeviceAdapterAlgorithm VTKM_CONT static T ScanInclusive(const vtkm::cont::ArrayHandle& input, vtkm::cont::ArrayHandle& output); - /// \brief Streaming version of scan inclusive + /// \brief Streaming version of scan exclusive /// /// Computes a scan one block at a time. /// /// \return The total sum. /// template - VTKM_CONT static T StreamingScanInclusive(const vtkm::Id numBlocks, + VTKM_CONT static T StreamingScanExclusive(const vtkm::Id numBlocks, const vtkm::cont::ArrayHandle& input, vtkm::cont::ArrayHandle& output); @@ -282,18 +282,6 @@ struct DeviceAdapterAlgorithm const vtkm::cont::ArrayHandle& values, vtkm::cont::ArrayHandle& values_output); - /// \brief Streaming version of scan inclusive - /// - /// Computes a scan one block at a time. - /// - /// \return The total sum. - /// - template - VTKM_CONT static T StreamingScanInclusive(const vtkm::Id numBlocks, - const vtkm::cont::ArrayHandle& input, - vtkm::cont::ArrayHandle& output, - BinaryFunctor binary_functor); - /// \brief Compute an exclusive prefix sum operation on the input ArrayHandle. /// /// Computes an exclusive prefix sum operation on the \c input ArrayHandle, diff --git a/vtkm/cont/DynamicArrayHandle.h b/vtkm/cont/DynamicArrayHandle.h index bc600e648..bae63cd24 100644 --- a/vtkm/cont/DynamicArrayHandle.h +++ b/vtkm/cont/DynamicArrayHandle.h @@ -358,7 +358,7 @@ public: /// respectively. /// template - VTKM_CONT void CastAndCall(const Functor& f, Args&&...) const; + VTKM_CONT void CastAndCall(Functor&& f, Args&&...) 
const; /// \brief Create a new array of the same type as this array. /// @@ -414,15 +414,15 @@ struct DynamicArrayHandleTry } template - void operator()(std::pair&& p, Args&&... args) const + void operator()(brigand::list, Args&&... args) const { using storage = vtkm::cont::internal::Storage; using invalid = typename std::is_base_of::type; - this->run(std::forward(p), invalid{}, args...); + this->run(invalid{}, args...); } template - void run(std::pair&&, std::false_type, Functor&& f, bool& called, Args&&... args) const + void run(std::false_type, Functor&& f, bool& called, Args&&... args) const { if (!called) { @@ -437,7 +437,7 @@ struct DynamicArrayHandleTry } template - void run(std::pair&&, std::true_type, Args&&...) const + void run(std::true_type, Args&&...) const { } @@ -451,7 +451,7 @@ VTKM_CONT_EXPORT void ThrowCastAndCallException(PolymorphicArrayHandleContainerB template template -VTKM_CONT void DynamicArrayHandleBase::CastAndCall(const Functor& f, +VTKM_CONT void DynamicArrayHandleBase::CastAndCall(Functor&& f, Args&&... args) const { //For optimizations we should compile once the cross product for the default types @@ -460,8 +460,11 @@ VTKM_CONT void DynamicArrayHandleBase::CastAndCall(const bool called = false; auto* ptr = this->ArrayContainer.get(); - vtkm::ListForEach( - detail::DynamicArrayHandleTry(ptr), crossProduct{}, f, called, std::forward(args)...); + vtkm::ListForEach(detail::DynamicArrayHandleTry(ptr), + crossProduct{}, + std::forward(f), + called, + std::forward(args)...); if (!called) { // throw an exception diff --git a/vtkm/cont/DynamicCellSet.h b/vtkm/cont/DynamicCellSet.h index 3df58eaf3..aeb8ede00 100644 --- a/vtkm/cont/DynamicCellSet.h +++ b/vtkm/cont/DynamicCellSet.h @@ -228,7 +228,7 @@ public: /// behavior from \c CastAndCall. /// template - VTKM_CONT void CastAndCall(const Functor& f, Args&&...) const; + VTKM_CONT void CastAndCall(Functor&& f, Args&&...) const; /// \brief Create a new cell set of the same type as this cell set. 
/// @@ -302,11 +302,12 @@ struct DynamicCellSetTry template template -VTKM_CONT void DynamicCellSetBase::CastAndCall(const Functor& f, Args&&... args) const +VTKM_CONT void DynamicCellSetBase::CastAndCall(Functor&& f, Args&&... args) const { bool called = false; detail::DynamicCellSetTry tryCellSet(this->CellSetContainer.get()); - vtkm::ListForEach(tryCellSet, CellSetList{}, f, called, std::forward(args)...); + vtkm::ListForEach( + tryCellSet, CellSetList{}, std::forward(f), called, std::forward(args)...); if (!called) { throw vtkm::cont::ErrorBadValue("Could not find appropriate cast for cell set."); diff --git a/vtkm/cont/EnvironmentTracker.cxx b/vtkm/cont/EnvironmentTracker.cxx new file mode 100644 index 000000000..942ea4255 --- /dev/null +++ b/vtkm/cont/EnvironmentTracker.cxx @@ -0,0 +1,67 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2014 UT-Battelle, LLC. +// Copyright 2014 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. 
+//============================================================================ +#include + +#if defined(VTKM_ENABLE_MPI) +#include +#else +namespace diy +{ +namespace mpi +{ +class communicator +{ +}; +} +} +#endif + +namespace vtkm +{ +namespace cont +{ +#if defined(VTKM_ENABLE_MPI) +namespace internal +{ +static diy::mpi::communicator GlobalCommuncator(MPI_COMM_NULL); +} + +void EnvironmentTracker::SetCommunicator(const diy::mpi::communicator& comm) +{ + vtkm::cont::internal::GlobalCommuncator = comm; +} + +const diy::mpi::communicator& EnvironmentTracker::GetCommunicator() +{ + return vtkm::cont::internal::GlobalCommuncator; +} +#else +void EnvironmentTracker::SetCommunicator(const diy::mpi::communicator&) +{ +} + +const diy::mpi::communicator& EnvironmentTracker::GetCommunicator() +{ + static diy::mpi::communicator tmp; + return tmp; +} +#endif +} // namespace vtkm::cont +} // namespace vtkm diff --git a/vtkm/cont/EnvironmentTracker.h b/vtkm/cont/EnvironmentTracker.h new file mode 100644 index 000000000..a046f8c77 --- /dev/null +++ b/vtkm/cont/EnvironmentTracker.h @@ -0,0 +1,53 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2014 UT-Battelle, LLC. +// Copyright 2014 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. 
+//============================================================================ +#ifndef vtk_m_cont_EnvironmentTracker_h +#define vtk_m_cont_EnvironmentTracker_h + +#include +#include +#include +#include + +namespace diy +{ +namespace mpi +{ +class communicator; +} +} + +namespace vtkm +{ +namespace cont +{ +class VTKM_CONT_EXPORT EnvironmentTracker +{ +public: + VTKM_CONT + static void SetCommunicator(const diy::mpi::communicator& comm); + + VTKM_CONT + static const diy::mpi::communicator& GetCommunicator(); +}; +} +} + + +#endif // vtk_m_cont_EnvironmentTracker_h diff --git a/vtkm/cont/Field.h b/vtkm/cont/Field.h index 54ca6199e..a19a3a729 100644 --- a/vtkm/cont/Field.h +++ b/vtkm/cont/Field.h @@ -102,32 +102,6 @@ public: VTKM_ASSERT((this->Association == ASSOC_WHOLE_MESH) || (this->Association == ASSOC_POINTS)); } - template - VTKM_CONT Field(std::string name, AssociationEnum association, const std::vector& data) - : Name(name) - , Association(association) - , AssocCellSetName() - , AssocLogicalDim(-1) - , Range() - , ModifiedFlag(true) - { - VTKM_ASSERT((this->Association == ASSOC_WHOLE_MESH) || (this->Association == ASSOC_POINTS)); - this->CopyData(&data[0], static_cast(data.size())); - } - - template - VTKM_CONT Field(std::string name, AssociationEnum association, const T* data, vtkm::Id nvals) - : Name(name) - , Association(association) - , AssocCellSetName() - , AssocLogicalDim(-1) - , Range() - , ModifiedFlag(true) - { - VTKM_ASSERT((this->Association == ASSOC_WHOLE_MESH) || (this->Association == ASSOC_POINTS)); - this->CopyData(data, nvals); - } - /// constructors for cell set associations VTKM_CONT Field(std::string name, @@ -161,39 +135,6 @@ public: VTKM_ASSERT(this->Association == ASSOC_CELL_SET); } - template - VTKM_CONT Field(std::string name, - AssociationEnum association, - const std::string& cellSetName, - const std::vector& data) - : Name(name) - , Association(association) - , AssocCellSetName(cellSetName) - , AssocLogicalDim(-1) - , Range() - , 
ModifiedFlag(true) - { - VTKM_ASSERT(this->Association == ASSOC_CELL_SET); - this->CopyData(&data[0], static_cast(data.size())); - } - - template - VTKM_CONT Field(std::string name, - AssociationEnum association, - const std::string& cellSetName, - const T* data, - vtkm::Id nvals) - : Name(name) - , Association(association) - , AssocCellSetName(cellSetName) - , AssocLogicalDim(-1) - , Range() - , ModifiedFlag(true) - { - VTKM_ASSERT(this->Association == ASSOC_CELL_SET); - this->CopyData(data, nvals); - } - /// constructors for logical dimension associations VTKM_CONT Field(std::string name, @@ -226,37 +167,6 @@ public: VTKM_ASSERT(this->Association == ASSOC_LOGICAL_DIM); } - template - VTKM_CONT Field(std::string name, - AssociationEnum association, - vtkm::IdComponent logicalDim, - const std::vector& data) - : Name(name) - , Association(association) - , AssocLogicalDim(logicalDim) - , Range() - , ModifiedFlag(true) - { - VTKM_ASSERT(this->Association == ASSOC_LOGICAL_DIM); - this->CopyData(&data[0], static_cast(data.size())); - } - - template - VTKM_CONT Field(std::string name, - AssociationEnum association, - vtkm::IdComponent logicalDim, - const T* data, - vtkm::Id nvals) - : Name(name) - , Association(association) - , AssocLogicalDim(logicalDim) - , Range() - , ModifiedFlag(true) - { - VTKM_ASSERT(this->Association == ASSOC_LOGICAL_DIM); - CopyData(data, nvals); - } - VTKM_CONT Field() : Name() @@ -356,17 +266,7 @@ public: template VTKM_CONT void CopyData(const T* ptr, vtkm::Id nvals) { - //allocate main memory using an array handle - vtkm::cont::ArrayHandle tmp; - tmp.Allocate(nvals); - - //copy into the memory owned by the array handle - std::copy(ptr, - ptr + static_cast(nvals), - vtkm::cont::ArrayPortalToIteratorBegin(tmp.GetPortalControl())); - - //assign to the dynamic array handle - this->Data = tmp; + this->Data = vtkm::cont::make_ArrayHandle(ptr, nvals, true); this->ModifiedFlag = true; } @@ -402,11 +302,78 @@ private: }; template -void 
CastAndCall(const vtkm::cont::Field& field, const Functor& f, Args&&... args) +void CastAndCall(const vtkm::cont::Field& field, Functor&& f, Args&&... args) { - field.GetData().CastAndCall(f, std::forward(args)...); + field.GetData().CastAndCall(std::forward(f), std::forward(args)...); } +//@{ +/// Convinience functions to build fields from C style arrays and std::vector +template +vtkm::cont::Field make_Field(std::string name, + Field::AssociationEnum association, + const T* data, + vtkm::Id size, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::Field(name, association, vtkm::cont::make_ArrayHandle(data, size, copy)); +} + +template +vtkm::cont::Field make_Field(std::string name, + Field::AssociationEnum association, + const std::vector& data, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::Field(name, association, vtkm::cont::make_ArrayHandle(data, copy)); +} + +template +vtkm::cont::Field make_Field(std::string name, + Field::AssociationEnum association, + const std::string& cellSetName, + const T* data, + vtkm::Id size, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::Field( + name, association, cellSetName, vtkm::cont::make_ArrayHandle(data, size, copy)); +} + +template +vtkm::cont::Field make_Field(std::string name, + Field::AssociationEnum association, + const std::string& cellSetName, + const std::vector& data, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::Field( + name, association, cellSetName, vtkm::cont::make_ArrayHandle(data, copy)); +} + +template +vtkm::cont::Field make_Field(std::string name, + Field::AssociationEnum association, + vtkm::IdComponent logicalDim, + const T* data, + vtkm::Id size, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::Field( + name, association, logicalDim, vtkm::cont::make_ArrayHandle(data, size, copy)); +} + +template +vtkm::cont::Field make_Field(std::string name, + Field::AssociationEnum association, + 
vtkm::IdComponent logicalDim, + const std::vector& data, + vtkm::CopyFlag copy = vtkm::CopyFlag::Off) +{ + return vtkm::cont::Field(name, association, logicalDim, vtkm::cont::make_ArrayHandle(data, copy)); +} +//@} + namespace internal { diff --git a/vtkm/cont/MultiBlock.cxx b/vtkm/cont/MultiBlock.cxx index 0064c410b..9d71bfb50 100644 --- a/vtkm/cont/MultiBlock.cxx +++ b/vtkm/cont/MultiBlock.cxx @@ -19,13 +19,142 @@ //============================================================================ #include +#include #include #include #include #include +#include #include #include #include + +#if defined(VTKM_ENABLE_MPI) +#include + +namespace vtkm +{ +namespace cont +{ +namespace detail +{ +template +VTKM_CONT std::vector CopyArrayPortalToVector( + const PortalType& portal) +{ + using ValueType = typename PortalType::ValueType; + std::vector result(portal.GetNumberOfValues()); + vtkm::cont::ArrayPortalToIterators iterators(portal); + std::copy(iterators.GetBegin(), iterators.GetEnd(), result.begin()); + return result; +} +} +} +} + +namespace std +{ + +namespace detail +{ + +template +struct MPIPlus +{ + MPIPlus() + { + this->OpPtr = std::shared_ptr(new MPI_Op(MPI_NO_OP), [](MPI_Op* ptr) { + MPI_Op_free(ptr); + delete ptr; + }); + + MPI_Op_create( + [](void* a, void* b, int* len, MPI_Datatype*) { + T* ba = reinterpret_cast(a); + T* bb = reinterpret_cast(b); + for (int cc = 0; cc < (*len) / ElementSize; ++cc) + { + bb[cc] = ba[cc] + bb[cc]; + } + }, + 1, + this->OpPtr.get()); + } + ~MPIPlus() {} + operator MPI_Op() const { return *this->OpPtr.get(); } +private: + std::shared_ptr OpPtr; +}; + +} // std::detail + +template <> +struct plus +{ + MPI_Op get_mpi_op() const { return this->Op; } + vtkm::Bounds operator()(const vtkm::Bounds& lhs, const vtkm::Bounds& rhs) const + { + return lhs + rhs; + } + +private: + std::detail::MPIPlus Op; +}; + +template <> +struct plus +{ + MPI_Op get_mpi_op() const { return this->Op; } + vtkm::Range operator()(const vtkm::Range& lhs, const 
vtkm::Range& rhs) const { return lhs + rhs; } + +private: + std::detail::MPIPlus Op; +}; +} + +namespace diy +{ +namespace mpi +{ +namespace detail +{ +template <> +struct mpi_datatype +{ + static MPI_Datatype datatype() { return get_mpi_datatype(); } + static const void* address(const vtkm::Bounds& x) { return &x; } + static void* address(vtkm::Bounds& x) { return &x; } + static int count(const vtkm::Bounds&) { return 6; } +}; + +template <> +struct mpi_op> +{ + static MPI_Op get(const std::plus& op) { return op.get_mpi_op(); } +}; + +template <> +struct mpi_datatype +{ + static MPI_Datatype datatype() { return get_mpi_datatype(); } + static const void* address(const vtkm::Range& x) { return &x; } + static void* address(vtkm::Range& x) { return &x; } + static int count(const vtkm::Range&) { return 2; } +}; + +template <> +struct mpi_op> +{ + static MPI_Op get(const std::plus& op) { return op.get_mpi_op(); } +}; + +} // diy::mpi::detail +} // diy::mpi +} // diy + + +#endif + namespace vtkm { namespace cont @@ -34,25 +163,25 @@ namespace cont VTKM_CONT MultiBlock::MultiBlock(const vtkm::cont::DataSet& ds) { - this->blocks.insert(blocks.end(), ds); + this->Blocks.insert(this->Blocks.end(), ds); } VTKM_CONT MultiBlock::MultiBlock(const vtkm::cont::MultiBlock& src) { - this->blocks = src.GetBlocks(); + this->Blocks = src.GetBlocks(); } VTKM_CONT MultiBlock::MultiBlock(const std::vector& mblocks) { - this->blocks = mblocks; + this->Blocks = mblocks; } VTKM_CONT MultiBlock::MultiBlock(vtkm::Id size) { - this->blocks.reserve(static_cast(size)); + this->Blocks.reserve(static_cast(size)); } VTKM_CONT @@ -68,7 +197,7 @@ MultiBlock::~MultiBlock() VTKM_CONT MultiBlock& MultiBlock::operator=(const vtkm::cont::MultiBlock& src) { - this->blocks = src.GetBlocks(); + this->Blocks = src.GetBlocks(); return *this; } @@ -76,46 +205,68 @@ VTKM_CONT vtkm::cont::Field MultiBlock::GetField(const std::string& field_name, const int& block_index) { assert(block_index >= 0); - 
assert(static_cast(block_index) < blocks.size()); - return blocks[static_cast(block_index)].GetField(field_name); + assert(static_cast(block_index) < this->Blocks.size()); + return this->Blocks[static_cast(block_index)].GetField(field_name); } VTKM_CONT vtkm::Id MultiBlock::GetNumberOfBlocks() const { - return static_cast(this->blocks.size()); + return static_cast(this->Blocks.size()); +} + +VTKM_CONT +vtkm::Id MultiBlock::GetGlobalNumberOfBlocks() const +{ +#if defined(VTKM_ENABLE_MPI) + auto world = vtkm::cont::EnvironmentTracker::GetCommunicator(); + const auto local_count = this->GetNumberOfBlocks(); + + diy::Master master(world, 1, -1); + int block_not_used = 1; + master.add(world.rank(), &block_not_used, new diy::Link()); + // empty link since we're only using collectives. + master.foreach ([=](void*, const diy::Master::ProxyWithLink& cp) { + cp.all_reduce(local_count, std::plus()); + }); + master.process_collectives(); + vtkm::Id global_count = master.proxy(0).get(); + return global_count; +#else + return this->GetNumberOfBlocks(); +#endif } VTKM_CONT const vtkm::cont::DataSet& MultiBlock::GetBlock(vtkm::Id blockId) const { - return this->blocks[static_cast(blockId)]; + return this->Blocks[static_cast(blockId)]; } VTKM_CONT const std::vector& MultiBlock::GetBlocks() const { - return this->blocks; + return this->Blocks; } VTKM_CONT void MultiBlock::AddBlock(vtkm::cont::DataSet& ds) { - this->blocks.insert(blocks.end(), ds); + this->Blocks.insert(this->Blocks.end(), ds); return; } void MultiBlock::AddBlocks(std::vector& mblocks) { - this->blocks.insert(blocks.end(), mblocks.begin(), mblocks.end()); + this->Blocks.insert(this->Blocks.end(), mblocks.begin(), mblocks.end()); return; } VTKM_CONT void MultiBlock::InsertBlock(vtkm::Id index, vtkm::cont::DataSet& ds) { - if (index <= static_cast(blocks.size())) - this->blocks.insert(blocks.begin() + index, ds); + if (index <= static_cast(this->Blocks.size())) + this->Blocks.insert(this->Blocks.begin() + index, ds); 
else { std::string msg = "invalid insert position\n "; @@ -126,8 +277,8 @@ void MultiBlock::InsertBlock(vtkm::Id index, vtkm::cont::DataSet& ds) VTKM_CONT void MultiBlock::ReplaceBlock(vtkm::Id index, vtkm::cont::DataSet& ds) { - if (index < static_cast(blocks.size())) - this->blocks.at(static_cast(index)) = ds; + if (index < static_cast(this->Blocks.size())) + this->Blocks.at(static_cast(index)) = ds; else { std::string msg = "invalid replace position\n "; @@ -158,8 +309,32 @@ VTKM_CONT vtkm::Bounds MultiBlock::GetBounds(vtkm::Id coordinate_system_index, VTKM_IS_LIST_TAG(TypeList); VTKM_IS_LIST_TAG(StorageList); +#if defined(VTKM_ENABLE_MPI) + auto world = vtkm::cont::EnvironmentTracker::GetCommunicator(); + //const auto global_num_blocks = this->GetGlobalNumberOfBlocks(); + + const auto num_blocks = this->GetNumberOfBlocks(); + + diy::Master master(world, 1, -1); + for (vtkm::Id cc = 0; cc < num_blocks; ++cc) + { + int gid = cc * world.size() + world.rank(); + master.add(gid, const_cast(&this->Blocks[cc]), new diy::Link()); + } + + master.foreach ([&](const vtkm::cont::DataSet* block, const diy::Master::ProxyWithLink& cp) { + auto coords = block->GetCoordinateSystem(coordinate_system_index); + const vtkm::Bounds bounds = coords.GetBounds(TypeList(), StorageList()); + cp.all_reduce(bounds, std::plus()); + }); + + master.process_collectives(); + auto bounds = master.proxy(0).get(); + return bounds; + +#else const vtkm::Id index = coordinate_system_index; - const size_t num_blocks = blocks.size(); + const size_t num_blocks = this->Blocks.size(); vtkm::Bounds bounds; for (size_t i = 0; i < num_blocks; ++i) @@ -167,8 +342,8 @@ VTKM_CONT vtkm::Bounds MultiBlock::GetBounds(vtkm::Id coordinate_system_index, vtkm::Bounds block_bounds = this->GetBlockBounds(i, index, TypeList(), StorageList()); bounds.Include(block_bounds); } - return bounds; +#endif } VTKM_CONT @@ -206,7 +381,7 @@ VTKM_CONT vtkm::Bounds MultiBlock::GetBlockBounds(const std::size_t& block_index 
vtkm::cont::CoordinateSystem coords; try { - coords = blocks[block_index].GetCoordinateSystem(index); + coords = this->Blocks[block_index].GetCoordinateSystem(index); } catch (const vtkm::cont::Error& error) { @@ -241,8 +416,8 @@ VTKM_CONT vtkm::cont::ArrayHandle MultiBlock::GetGlobalRange(const VTKM_IS_LIST_TAG(TypeList); VTKM_IS_LIST_TAG(StorageList); - assert(blocks.size() > 0); - vtkm::cont::Field field = blocks.at(0).GetField(index); + assert(this->Blocks.size() > 0); + vtkm::cont::Field field = this->Blocks.at(0).GetField(index); std::string field_name = field.GetName(); return this->GetGlobalRange(field_name, TypeList(), StorageList()); } @@ -267,21 +442,86 @@ template VTKM_CONT vtkm::cont::ArrayHandle MultiBlock::GetGlobalRange(const std::string& field_name, TypeList, StorageList) const { +#if defined(VTKM_ENABLE_MPI) + auto world = vtkm::cont::EnvironmentTracker::GetCommunicator(); + const auto num_blocks = this->GetNumberOfBlocks(); + + diy::Master master(world); + for (vtkm::Id cc = 0; cc < num_blocks; ++cc) + { + int gid = cc * world.size() + world.rank(); + master.add(gid, const_cast(&this->Blocks[cc]), new diy::Link()); + } + + // collect info about number of components in the field. + master.foreach ([&](const vtkm::cont::DataSet* dataset, const diy::Master::ProxyWithLink& cp) { + if (dataset->HasField(field_name)) + { + auto field = dataset->GetField(field_name); + const vtkm::cont::ArrayHandle range = field.GetRange(TypeList(), StorageList()); + vtkm::Id components = range.GetPortalConstControl().GetNumberOfValues(); + cp.all_reduce(components, diy::mpi::maximum()); + } + }); + master.process_collectives(); + + const vtkm::Id components = master.size() ? master.proxy(0).read() : 0; + + // clear all collectives. 
+ master.foreach ([&](const vtkm::cont::DataSet*, const diy::Master::ProxyWithLink& cp) { + cp.collectives()->clear(); + }); + + master.foreach ([&](const vtkm::cont::DataSet* dataset, const diy::Master::ProxyWithLink& cp) { + if (dataset->HasField(field_name)) + { + auto field = dataset->GetField(field_name); + const vtkm::cont::ArrayHandle range = field.GetRange(TypeList(), StorageList()); + const auto v_range = + vtkm::cont::detail::CopyArrayPortalToVector(range.GetPortalConstControl()); + for (const vtkm::Range& r : v_range) + { + cp.all_reduce(r, std::plus()); + } + // if current block has less that the max number of components, just add invalid ranges for the rest. + for (vtkm::Id cc = static_cast(v_range.size()); cc < components; ++cc) + { + cp.all_reduce(vtkm::Range(), std::plus()); + } + } + }); + master.process_collectives(); + std::vector ranges(components); + // FIXME: is master.size() == 0 i.e. there are no blocks on the current rank, + // this method won't return valid range. 
+ if (master.size() > 0) + { + for (vtkm::Id cc = 0; cc < components; ++cc) + { + ranges[cc] = master.proxy(0).get(); + } + } + + vtkm::cont::ArrayHandle tmprange = vtkm::cont::make_ArrayHandle(ranges); + vtkm::cont::ArrayHandle range; + vtkm::cont::ArrayCopy(vtkm::cont::make_ArrayHandle(ranges), range); + return range; +#else bool valid_field = true; - const size_t num_blocks = blocks.size(); + const size_t num_blocks = this->Blocks.size(); vtkm::cont::ArrayHandle range; vtkm::Id num_components = 0; for (size_t i = 0; i < num_blocks; ++i) { - if (!blocks[i].HasField(field_name)) + if (!this->Blocks[i].HasField(field_name)) { valid_field = false; break; } - const vtkm::cont::Field& field = blocks[i].GetField(field_name); + const vtkm::cont::Field& field = this->Blocks[i].GetField(field_name); vtkm::cont::ArrayHandle sub_range = field.GetRange(TypeList(), StorageList()); vtkm::cont::ArrayHandle::PortalConstControl sub_range_control = @@ -324,6 +564,7 @@ MultiBlock::GetGlobalRange(const std::string& field_name, TypeList, StorageList) } return range; +#endif } VTKM_CONT @@ -332,10 +573,10 @@ void MultiBlock::PrintSummary(std::ostream& stream) const stream << "block " << "\n"; - for (size_t block_index = 0; block_index < blocks.size(); ++block_index) + for (size_t block_index = 0; block_index < this->Blocks.size(); ++block_index) { stream << "block " << block_index << "\n"; - blocks[block_index].PrintSummary(stream); + this->Blocks[block_index].PrintSummary(stream); } } } diff --git a/vtkm/cont/MultiBlock.h b/vtkm/cont/MultiBlock.h index 5a78e4fda..dad32672c 100644 --- a/vtkm/cont/MultiBlock.h +++ b/vtkm/cont/MultiBlock.h @@ -64,6 +64,13 @@ public: VTKM_CONT vtkm::Id GetNumberOfBlocks() const; + /// Returns the number of blocks across all ranks. For non-MPI builds, this + /// will be same as `GetNumberOfBlocks()`. + /// This method is not thread-safe and may involve global communication across + /// all ranks in distributed environments with MPI. 
+ VTKM_CONT + vtkm::Id GetGlobalNumberOfBlocks() const; + VTKM_CONT const vtkm::cont::DataSet& GetBlock(vtkm::Id blockId) const; @@ -105,7 +112,11 @@ public: vtkm::Id coordinate_system_index, TypeList, StorageList) const; - /// get the unified range of the same feild within all contained DataSet + + //@{ + /// Get the unified range of the same field within all contained DataSet. + /// These methods are not thread-safe and may involve global communication + /// across all ranks in distributed environments with MPI. VTKM_CONT vtkm::cont::ArrayHandle GetGlobalRange(const std::string& field_name) const; @@ -128,12 +139,13 @@ public: VTKM_CONT vtkm::cont::ArrayHandle GetGlobalRange(const int& index, TypeList, StorageList) const; + //@} VTKM_CONT void PrintSummary(std::ostream& stream) const; private: - std::vector blocks; + std::vector Blocks; }; } } // namespace vtkm::cont diff --git a/vtkm/cont/cuda/internal/CudaAllocator.cu b/vtkm/cont/cuda/internal/CudaAllocator.cu index aa7c35ea2..a1c4fe84c 100644 --- a/vtkm/cont/cuda/internal/CudaAllocator.cu +++ b/vtkm/cont/cuda/internal/CudaAllocator.cu @@ -33,6 +33,10 @@ static bool IsInitialized = false; // True if all devices support concurrent pagable managed memory. static bool ManagedMemorySupported = false; + +// Avoid overhead of cudaMemAdvise and cudaMemPrefetchAsync for small buffers. +// This value should be > 0 or else these functions will error out. +static std::size_t Threshold = 1 << 20; } namespace vtkm @@ -94,6 +98,12 @@ bool CudaAllocator::IsManagedPointer(const void* ptr) void* CudaAllocator::Allocate(std::size_t numBytes) { CudaAllocator::Initialize(); + // When numBytes is zero cudaMallocManaged returns an error and the behavior + // of cudaMalloc is not documented. Just return nullptr. 
+ if (numBytes == 0) + { + return nullptr; + } void* ptr = nullptr; if (ManagedMemorySupported) @@ -115,7 +125,7 @@ void CudaAllocator::Free(void* ptr) void CudaAllocator::PrepareForControl(const void* ptr, std::size_t numBytes) { - if (IsManagedPointer(ptr)) + if (IsManagedPointer(ptr) && numBytes >= Threshold) { #if CUDART_VERSION >= 8000 // TODO these hints need to be benchmarked and adjusted once we start @@ -128,7 +138,7 @@ void CudaAllocator::PrepareForControl(const void* ptr, std::size_t numBytes) void CudaAllocator::PrepareForInput(const void* ptr, std::size_t numBytes) { - if (IsManagedPointer(ptr)) + if (IsManagedPointer(ptr) && numBytes >= Threshold) { #if CUDART_VERSION >= 8000 int dev; @@ -143,7 +153,7 @@ void CudaAllocator::PrepareForInput(const void* ptr, std::size_t numBytes) void CudaAllocator::PrepareForOutput(const void* ptr, std::size_t numBytes) { - if (IsManagedPointer(ptr)) + if (IsManagedPointer(ptr) && numBytes >= Threshold) { #if CUDART_VERSION >= 8000 int dev; @@ -158,7 +168,7 @@ void CudaAllocator::PrepareForOutput(const void* ptr, std::size_t numBytes) void CudaAllocator::PrepareForInPlace(const void* ptr, std::size_t numBytes) { - if (IsManagedPointer(ptr)) + if (IsManagedPointer(ptr) && numBytes >= Threshold) { #if CUDART_VERSION >= 8000 int dev; diff --git a/vtkm/cont/internal/DynamicTransform.h b/vtkm/cont/internal/DynamicTransform.h index 1195723c0..61431f484 100644 --- a/vtkm/cont/internal/DynamicTransform.h +++ b/vtkm/cont/internal/DynamicTransform.h @@ -48,28 +48,28 @@ class CellSetPermutation; /// DynamicObject's CastAndCall, but specializations of this function exist for /// other classes (e.g. Field, CoordinateSystem, ArrayHandle). template -void CastAndCall(const DynamicObject& dynamicObject, const Functor& f, Args&&... args) +void CastAndCall(const DynamicObject& dynamicObject, Functor&& f, Args&&... 
args) { - dynamicObject.CastAndCall(f, std::forward(args)...); + dynamicObject.CastAndCall(std::forward(f), std::forward(args)...); } /// A specialization of CastAndCall for basic CoordinateSystem to make /// it be treated just like any other dynamic object // actually implemented in vtkm/cont/CoordinateSystem template -void CastAndCall(const CoordinateSystem& coords, const Functor& f, Args&&... args); +void CastAndCall(const CoordinateSystem& coords, Functor&& f, Args&&... args); /// A specialization of CastAndCall for basic Field to make /// it be treated just like any other dynamic object // actually implemented in vtkm/cont/Field template -void CastAndCall(const vtkm::cont::Field& field, const Functor& f, Args&&... args); +void CastAndCall(const vtkm::cont::Field& field, Functor&& f, Args&&... args); /// A specialization of CastAndCall for basic ArrayHandle types, /// Since the type is already known no deduction is needed. /// This specialization is used to simplify numerous worklet algorithms template -void CastAndCall(const vtkm::cont::ArrayHandle& handle, const Functor& f, Args&&... args) +void CastAndCall(const vtkm::cont::ArrayHandle& handle, Functor&& f, Args&&... args) { f(handle, std::forward(args)...); } @@ -78,9 +78,7 @@ void CastAndCall(const vtkm::cont::ArrayHandle& handle, const Functor& f, /// Since the type is already known no deduction is needed. /// This specialization is used to simplify numerous worklet algorithms template -void CastAndCall(const vtkm::cont::CellSetStructured& cellset, - const Functor& f, - Args&&... args) +void CastAndCall(const vtkm::cont::CellSetStructured& cellset, Functor&& f, Args&&... args) { f(cellset, std::forward(args)...); } @@ -90,7 +88,7 @@ void CastAndCall(const vtkm::cont::CellSetStructured& cellset, /// This specialization is used to simplify numerous worklet algorithms template void CastAndCall(const vtkm::cont::CellSetSingleType& cellset, - const Functor& f, + Functor&& f, Args&&... 
args) { f(cellset, std::forward(args)...); @@ -101,7 +99,7 @@ void CastAndCall(const vtkm::cont::CellSetSingleType& ce /// This specialization is used to simplify numerous worklet algorithms template void CastAndCall(const vtkm::cont::CellSetExplicit& cellset, - const Functor& f, + Functor&& f, Args&&... args) { f(cellset, std::forward(args)...); @@ -112,7 +110,7 @@ void CastAndCall(const vtkm::cont::CellSetExplicit& cellset, /// This specialization is used to simplify numerous worklet algorithms template void CastAndCall(const vtkm::cont::CellSetPermutation& cellset, - const Functor& f, + Functor&& f, Args&&... args) { f(cellset, std::forward(args)...); diff --git a/vtkm/cont/testing/CMakeLists.txt b/vtkm/cont/testing/CMakeLists.txt index 8af665b81..22968cb2c 100644 --- a/vtkm/cont/testing/CMakeLists.txt +++ b/vtkm/cont/testing/CMakeLists.txt @@ -37,6 +37,7 @@ set(headers vtkm_declare_headers(${headers}) set(unit_tests + UnitTestAlgorithm.cxx UnitTestArrayCopy.cxx UnitTestArrayHandleCartesianProduct.cxx UnitTestArrayHandleCompositeVector.cxx @@ -65,7 +66,7 @@ set(unit_tests UnitTestDeviceAdapterAlgorithmGeneral.cxx UnitTestDynamicArrayHandle.cxx UnitTestDynamicCellSet.cxx - UnitTestMultiBlock.cxx + UnitTestMultiBlock.cxx,MPI UnitTestRuntimeDeviceInformation.cxx UnitTestStorageBasic.cxx UnitTestStorageImplicit.cxx diff --git a/vtkm/cont/testing/MakeTestDataSet.h b/vtkm/cont/testing/MakeTestDataSet.h index c11e6f842..8ba9caf46 100644 --- a/vtkm/cont/testing/MakeTestDataSet.h +++ b/vtkm/cont/testing/MakeTestDataSet.h @@ -53,6 +53,7 @@ public: // 3D uniform datasets. 
vtkm::cont::DataSet Make3DUniformDataSet0(); vtkm::cont::DataSet Make3DUniformDataSet1(); + vtkm::cont::DataSet Make3DUniformDataSet2(); vtkm::cont::DataSet Make3DRegularDataSet0(); vtkm::cont::DataSet Make3DRegularDataSet1(); @@ -245,6 +246,32 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DUniformDataSet1() return dataSet; } +inline vtkm::cont::DataSet MakeTestDataSet::Make3DUniformDataSet2() +{ + const vtkm::Id base_size = 256; + vtkm::cont::DataSetBuilderUniform dsb; + vtkm::Id3 dimensions(base_size, base_size, base_size); + vtkm::cont::DataSet dataSet = dsb.Create(dimensions); + + vtkm::cont::DataSetFieldAdd dsf; + const vtkm::Id nVerts = base_size * base_size * base_size; + vtkm::Float32* pointvar = new vtkm::Float32[nVerts]; + + for (vtkm::Int32 z = 0; z < base_size; ++z) + for (vtkm::Int32 y = 0; y < base_size; ++y) + for (vtkm::Int32 x = 0; x < base_size; ++x) + { + vtkm::Int32 index = z * base_size * base_size + y * base_size + x; + pointvar[index] = vtkm::Sqrt(vtkm::Float32(x * x + y * y + z * z)); + } + + dsf.AddPointField(dataSet, "pointvar", pointvar, nVerts); + + delete[] pointvar; + + return dataSet; +} + inline vtkm::cont::DataSet MakeTestDataSet::Make2DRectilinearDataSet0() { vtkm::cont::DataSetBuilderRectilinear dsb; @@ -287,11 +314,13 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DRegularDataSet0() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); //Set point scalar - dataSet.AddField(Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts)); + dataSet.AddField( + make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts, vtkm::CopyFlag::On)); //Set cell scalar vtkm::Float32 cellvar[4] = { 100.1f, 100.2f, 100.3f, 100.4f }; - dataSet.AddField(Field("cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 4)); + dataSet.AddField(make_Field( + "cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 4, vtkm::CopyFlag::On)); static const vtkm::IdComponent dim = 3; 
vtkm::cont::CellSetStructured cellSet("cells"); @@ -312,11 +341,13 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DRegularDataSet1() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); //Set point scalar - dataSet.AddField(Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts)); + dataSet.AddField( + make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts, vtkm::CopyFlag::On)); //Set cell scalar vtkm::Float32 cellvar[1] = { 100.1f }; - dataSet.AddField(Field("cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 1)); + dataSet.AddField(make_Field( + "cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 1, vtkm::CopyFlag::On)); static const vtkm::IdComponent dim = 3; vtkm::cont::CellSetStructured cellSet("cells"); @@ -556,7 +587,8 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DExplicitDataSet1() CoordType(2, 2, 0) }; vtkm::Float32 vars[nVerts] = { 10.1f, 20.1f, 30.2f, 40.2f, 50.3f }; - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates, nVerts)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, nVerts, vtkm::CopyFlag::On)); vtkm::cont::CellSetExplicit<> cellSet("cells"); cellSet.PrepareToAddCells(2, 7); cellSet.AddCell(vtkm::CELL_SHAPE_TRIANGLE, 3, make_Vec(0, 1, 2)); @@ -565,11 +597,13 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DExplicitDataSet1() dataSet.AddCellSet(cellSet); //Set point scalar - dataSet.AddField(Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts)); + dataSet.AddField( + make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts, vtkm::CopyFlag::On)); //Set cell scalar vtkm::Float32 cellvar[2] = { 100.1f, 100.2f }; - dataSet.AddField(Field("cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 2)); + dataSet.AddField(make_Field( + "cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 2, vtkm::CopyFlag::On)); return dataSet; } @@ 
-592,14 +626,17 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DExplicitDataSet2() }; vtkm::Float32 vars[nVerts] = { 10.1f, 20.1f, 30.2f, 40.2f, 50.3f, 60.2f, 70.2f, 80.3f }; - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates, nVerts)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, nVerts, vtkm::CopyFlag::On)); //Set point scalar - dataSet.AddField(Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts)); + dataSet.AddField( + make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts, vtkm::CopyFlag::On)); //Set cell scalar vtkm::Float32 cellvar[2] = { 100.1f }; - dataSet.AddField(Field("cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 1)); + dataSet.AddField(make_Field( + "cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 1, vtkm::CopyFlag::On)); vtkm::cont::CellSetExplicit<> cellSet("cells"); vtkm::Vec ids; @@ -645,14 +682,17 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DExplicitDataSet4() vtkm::Float32 vars[nVerts] = { 10.1f, 20.1f, 30.2f, 40.2f, 50.3f, 60.2f, 70.2f, 80.3f, 90.f, 10.f, 11.f, 12.f }; - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates, nVerts)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, nVerts, vtkm::CopyFlag::On)); //Set point scalar - dataSet.AddField(Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts)); + dataSet.AddField( + make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts, vtkm::CopyFlag::On)); //Set cell scalar vtkm::Float32 cellvar[2] = { 100.1f, 110.f }; - dataSet.AddField(Field("cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 2)); + dataSet.AddField(make_Field( + "cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 2, vtkm::CopyFlag::On)); vtkm::cont::CellSetExplicit<> cellSet("cells"); vtkm::Vec ids; @@ -695,14 +735,17 @@ inline vtkm::cont::DataSet 
MakeTestDataSet::Make3DExplicitDataSet3() }; vtkm::Float32 vars[nVerts] = { 10.1f, 10.1f, 10.2f, 30.2f }; - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates, nVerts)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, nVerts, vtkm::CopyFlag::On)); //Set point scalar - dataSet.AddField(Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts)); + dataSet.AddField( + make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts, vtkm::CopyFlag::On)); //Set cell scalar vtkm::Float32 cellvar[2] = { 100.1f }; - dataSet.AddField(Field("cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 1)); + dataSet.AddField(make_Field( + "cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, 1, vtkm::CopyFlag::On)); vtkm::cont::CellSetExplicit<> cellSet("cells"); vtkm::Vec ids; @@ -743,15 +786,18 @@ inline vtkm::cont::DataSet MakeTestDataSet::Make3DExplicitDataSet5() vtkm::Float32 vars[nVerts] = { 10.1f, 20.1f, 30.2f, 40.2f, 50.3f, 60.2f, 70.2f, 80.3f, 90.f, 10.f, 11.f }; - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates, nVerts)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, nVerts, vtkm::CopyFlag::On)); //Set point scalar - dataSet.AddField(Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts)); + dataSet.AddField( + make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, nVerts, vtkm::CopyFlag::On)); //Set cell scalar const int nCells = 4; vtkm::Float32 cellvar[nCells] = { 100.1f, 110.f, 120.2f, 130.5f }; - dataSet.AddField(Field("cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, nCells)); + dataSet.AddField(make_Field( + "cellvar", vtkm::cont::Field::ASSOC_CELL_SET, "cells", cellvar, nCells, vtkm::CopyFlag::On)); vtkm::cont::CellSetExplicit<> cellSet("cells"); vtkm::Vec ids; @@ -982,7 +1028,8 @@ inline vtkm::cont::DataSet 
MakeTestDataSet::Make3DExplicitDataSetCowNose() // create DataSet vtkm::cont::DataSet dataSet; - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates, nVerts)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, nVerts, vtkm::CopyFlag::On)); vtkm::cont::ArrayHandle connectivity; connectivity.Allocate(connectivitySize); diff --git a/vtkm/cont/testing/TestingComputeRange.h b/vtkm/cont/testing/TestingComputeRange.h index b902fb5c7..5a869a8d7 100644 --- a/vtkm/cont/testing/TestingComputeRange.h +++ b/vtkm/cont/testing/TestingComputeRange.h @@ -60,7 +60,7 @@ private: const vtkm::Id nvals = 11; T data[nvals] = { 1, 2, 3, 4, 5, -5, -4, -3, -2, -1, 0 }; std::random_shuffle(data, data + nvals); - vtkm::cont::Field field("TestField", vtkm::cont::Field::ASSOC_POINTS, data, nvals); + auto field = vtkm::cont::make_Field("TestField", vtkm::cont::Field::ASSOC_POINTS, data, nvals); vtkm::Range result; field.GetRange(&result); @@ -84,7 +84,8 @@ private: fieldData[j][i] = data[j]; } } - vtkm::cont::Field field("TestField", vtkm::cont::Field::ASSOC_POINTS, fieldData, nvals); + auto field = + vtkm::cont::make_Field("TestField", vtkm::cont::Field::ASSOC_POINTS, fieldData, nvals); vtkm::Range result[NumberOfComponents]; field.GetRange(result, CustomTypeList(), VTKM_DEFAULT_STORAGE_LIST_TAG()); diff --git a/vtkm/cont/testing/UnitTestAlgorithm.cxx b/vtkm/cont/testing/UnitTestAlgorithm.cxx new file mode 100644 index 000000000..9bb284b21 --- /dev/null +++ b/vtkm/cont/testing/UnitTestAlgorithm.cxx @@ -0,0 +1,185 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. 
+// +// Copyright 2017 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2017 UT-Battelle, LLC. +// Copyright 2017 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. +//============================================================================ + +#include + +#include + +#include + +namespace +{ +// The goal of this unit test is not to verify the correctness +// of the various algorithms. Since Algorithm is a header, we +// need to ensure we instatiate each algorithm in a source +// file to verify compilation. +// +static const vtkm::Id ARRAY_SIZE = 10; + +void CopyTest() +{ + vtkm::cont::ArrayHandle input; + vtkm::cont::ArrayHandle output; + vtkm::cont::ArrayHandle stencil; + + input.Allocate(ARRAY_SIZE); + output.Allocate(ARRAY_SIZE); + stencil.Allocate(ARRAY_SIZE); + + vtkm::cont::Algorithm::Copy(input, output); + vtkm::cont::Algorithm::CopyIf(input, stencil, output); + vtkm::cont::Algorithm::CopyIf(input, stencil, output, vtkm::LogicalNot()); + vtkm::cont::Algorithm::CopySubRange(input, 2, 1, output); +} + +void BoundsTest() +{ + + vtkm::cont::ArrayHandle input; + vtkm::cont::ArrayHandle output; + vtkm::cont::ArrayHandle values; + + input.Allocate(ARRAY_SIZE); + output.Allocate(ARRAY_SIZE); + values.Allocate(ARRAY_SIZE); + + vtkm::cont::Algorithm::LowerBounds(input, values, output); + vtkm::cont::Algorithm::LowerBounds(input, values, output, vtkm::Sum()); + vtkm::cont::Algorithm::LowerBounds(input, values); + + vtkm::cont::Algorithm::UpperBounds(input, values, output); + vtkm::cont::Algorithm::UpperBounds(input, values, output, vtkm::Sum()); + vtkm::cont::Algorithm::UpperBounds(input, values); +} + +void ReduceTest() +{ + + vtkm::cont::ArrayHandle input; + 
vtkm::cont::ArrayHandle keys; + vtkm::cont::ArrayHandle keysOut; + vtkm::cont::ArrayHandle valsOut; + + input.Allocate(ARRAY_SIZE); + keys.Allocate(ARRAY_SIZE); + keysOut.Allocate(ARRAY_SIZE); + valsOut.Allocate(ARRAY_SIZE); + + vtkm::Id result; + result = vtkm::cont::Algorithm::Reduce(input, vtkm::Id(0)); + result = vtkm::cont::Algorithm::Reduce(input, vtkm::Id(0), vtkm::Maximum()); + vtkm::cont::Algorithm::ReduceByKey(keys, input, keysOut, valsOut, vtkm::Maximum()); + (void)result; +} + +void ScanTest() +{ + + vtkm::cont::ArrayHandle input; + vtkm::cont::ArrayHandle output; + vtkm::cont::ArrayHandle keys; + + input.Allocate(ARRAY_SIZE); + output.Allocate(ARRAY_SIZE); + keys.Allocate(ARRAY_SIZE); + + vtkm::Id out; + out = vtkm::cont::Algorithm::ScanInclusive(input, output); + out = vtkm::cont::Algorithm::ScanInclusive(input, output, vtkm::Maximum()); + out = vtkm::cont::Algorithm::StreamingScanExclusive(1, input, output); + vtkm::cont::Algorithm::ScanInclusiveByKey(keys, input, output, vtkm::Maximum()); + vtkm::cont::Algorithm::ScanInclusiveByKey(keys, input, output); + out = vtkm::cont::Algorithm::ScanExclusive(input, output, vtkm::Maximum(), vtkm::Id(0)); + vtkm::cont::Algorithm::ScanExclusiveByKey(keys, input, output, vtkm::Id(0), vtkm::Maximum()); + vtkm::cont::Algorithm::ScanExclusiveByKey(keys, input, output); + (void)out; +} + +struct DummyFunctor : public vtkm::exec::FunctorBase +{ + template + VTKM_EXEC void operator()(IdType) const + { + } +}; + +void ScheduleTest() +{ + vtkm::cont::Algorithm::Schedule(DummyFunctor(), vtkm::Id(1)); + vtkm::Id3 id3(1, 1, 1); + vtkm::cont::Algorithm::Schedule(DummyFunctor(), id3); +} + +struct CompFunctor +{ + template + VTKM_EXEC_CONT bool operator()(const T& x, const T& y) const + { + return x < y; + } +}; + +void SortTest() +{ + vtkm::cont::ArrayHandle input; + vtkm::cont::ArrayHandle keys; + + input.Allocate(ARRAY_SIZE); + keys.Allocate(ARRAY_SIZE); + + vtkm::cont::Algorithm::Sort(input); + 
vtkm::cont::Algorithm::Sort(input, CompFunctor()); + vtkm::cont::Algorithm::SortByKey(keys, input); + vtkm::cont::Algorithm::SortByKey(keys, input, CompFunctor()); +} + +void SynchronizeTest() +{ + vtkm::cont::Algorithm::Synchronize(); +} + +void UniqueTest() +{ + vtkm::cont::ArrayHandle input; + + input.Allocate(ARRAY_SIZE); + + vtkm::cont::Algorithm::Unique(input); + vtkm::cont::Algorithm::Unique(input, CompFunctor()); +} + +void TestAll() +{ + CopyTest(); + BoundsTest(); + ReduceTest(); + ScanTest(); + ScheduleTest(); + SortTest(); + SynchronizeTest(); + UniqueTest(); +} + +} // anonymous namespace + +int UnitTestAlgorithm(int, char* []) +{ + return vtkm::cont::testing::Testing::Run(TestAll); +} diff --git a/vtkm/cont/testing/UnitTestMultiBlock.cxx b/vtkm/cont/testing/UnitTestMultiBlock.cxx index 7c01da39e..b341b9449 100644 --- a/vtkm/cont/testing/UnitTestMultiBlock.cxx +++ b/vtkm/cont/testing/UnitTestMultiBlock.cxx @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,10 @@ #include #include +#if defined(VTKM_ENABLE_MPI) +#include +#endif + void DataSet_Compare(vtkm::cont::DataSet& LeftDateSet, vtkm::cont::DataSet& RightDateSet); static void MultiBlockTest() { @@ -46,7 +51,14 @@ static void MultiBlockTest() multiblock.AddBlock(TDset1); multiblock.AddBlock(TDset2); + int procsize = 1; +#if defined(VTKM_ENABLE_MPI) + procsize = vtkm::cont::EnvironmentTracker::GetCommunicator().size(); +#endif + VTKM_TEST_ASSERT(multiblock.GetNumberOfBlocks() == 2, "Incorrect number of blocks"); + VTKM_TEST_ASSERT(multiblock.GetGlobalNumberOfBlocks() == 2 * procsize, + "Incorrect number of blocks"); vtkm::cont::DataSet TestDSet = multiblock.GetBlock(0); VTKM_TEST_ASSERT(TDset1.GetNumberOfFields() == TestDSet.GetNumberOfFields(), @@ -155,7 +167,13 @@ void DataSet_Compare(vtkm::cont::DataSet& LeftDateSet, vtkm::cont::DataSet& Righ return; } -int UnitTestMultiBlock(int, char* []) +int UnitTestMultiBlock(int argc, char* argv[]) { + (void)argc; 
+ (void)argv; +#if defined(VTKM_ENABLE_MPI) + diy::mpi::environment env(argc, argv); + vtkm::cont::EnvironmentTracker::SetCommunicator(diy::mpi::communicator(MPI_COMM_WORLD)); +#endif return vtkm::cont::testing::Testing::Run(MultiBlockTest); } diff --git a/vtkm/filter/NDEntropy.hxx b/vtkm/filter/NDEntropy.hxx index 16327043e..61662e343 100644 --- a/vtkm/filter/NDEntropy.hxx +++ b/vtkm/filter/NDEntropy.hxx @@ -18,7 +18,6 @@ // this software. //============================================================================ -#include #include #include @@ -57,11 +56,9 @@ inline VTKM_CONT vtkm::filter::Result NDEntropy::DoExecute( // Run worklet to calculate multi-variate entropy vtkm::Float64 entropy = ndEntropy.Run(device); - vtkm::cont::DataSet outputData; - std::vector entropyHandle; - entropyHandle.push_back(entropy); - outputData.AddField(vtkm::cont::Field("Entropy", vtkm::cont::Field::ASSOC_POINTS, entropyHandle)); + outputData.AddField(vtkm::cont::make_Field( + "Entropy", vtkm::cont::Field::ASSOC_POINTS, &entropy, 1, vtkm::CopyFlag::On)); //return outputData; return vtkm::filter::Result(outputData); diff --git a/vtkm/filter/testing/UnitTestFieldMetadata.cxx b/vtkm/filter/testing/UnitTestFieldMetadata.cxx index 10b2020a4..b02553dc3 100644 --- a/vtkm/filter/testing/UnitTestFieldMetadata.cxx +++ b/vtkm/filter/testing/UnitTestFieldMetadata.cxx @@ -59,7 +59,7 @@ void TestFieldTypesPoint() //verify the field helper works properly vtkm::Float32 vars[6] = { 10.1f, 20.1f, 30.1f, 40.1f, 50.1f, 60.1f }; - vtkm::cont::Field field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, 6); + auto field = vtkm::cont::make_Field("pointvar", vtkm::cont::Field::ASSOC_POINTS, vars, 6); vtkm::filter::FieldMetadata makeMDFromField(field); VTKM_TEST_ASSERT(makeMDFromField.IsPointField() == true, "point should be a point field"); VTKM_TEST_ASSERT(makeMDFromField.IsCellField() == false, "point can't be a cell field"); @@ -74,7 +74,8 @@ void TestFieldTypesCell() //verify the field helper works 
properly vtkm::Float32 vars[6] = { 10.1f, 20.1f, 30.1f, 40.1f, 50.1f, 60.1f }; - vtkm::cont::Field field("pointvar", vtkm::cont::Field::ASSOC_CELL_SET, std::string(), vars, 6); + auto field = + vtkm::cont::make_Field("pointvar", vtkm::cont::Field::ASSOC_CELL_SET, std::string(), vars, 6); vtkm::filter::FieldMetadata makeMDFromField(field); VTKM_TEST_ASSERT(makeMDFromField.IsPointField() == false, "cell can't be a point field"); VTKM_TEST_ASSERT(makeMDFromField.IsCellField() == true, "cell should be a cell field"); diff --git a/vtkm/filter/testing/UnitTestHistogramFilter.cxx b/vtkm/filter/testing/UnitTestHistogramFilter.cxx index 2c6d900b1..4dab645be 100644 --- a/vtkm/filter/testing/UnitTestHistogramFilter.cxx +++ b/vtkm/filter/testing/UnitTestHistogramFilter.cxx @@ -227,23 +227,28 @@ vtkm::cont::DataSet MakeTestDataSet() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); // Set point scalars - dataSet.AddField( - vtkm::cont::Field("p_poisson", vtkm::cont::Field::ASSOC_POINTS, poisson, nVerts)); - dataSet.AddField(vtkm::cont::Field("p_normal", vtkm::cont::Field::ASSOC_POINTS, normal, nVerts)); - dataSet.AddField( - vtkm::cont::Field("p_chiSquare", vtkm::cont::Field::ASSOC_POINTS, chiSquare, nVerts)); - dataSet.AddField( - vtkm::cont::Field("p_uniform", vtkm::cont::Field::ASSOC_POINTS, uniform, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "p_poisson", vtkm::cont::Field::ASSOC_POINTS, poisson, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_normal", vtkm::cont::Field::ASSOC_POINTS, normal, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_chiSquare", vtkm::cont::Field::ASSOC_POINTS, chiSquare, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_uniform", vtkm::cont::Field::ASSOC_POINTS, uniform, nVerts, vtkm::CopyFlag::On)); // Set cell scalars - dataSet.AddField( - vtkm::cont::Field("c_poisson", vtkm::cont::Field::ASSOC_CELL_SET, "cells", 
poisson, nCells)); - dataSet.AddField( - vtkm::cont::Field("c_normal", vtkm::cont::Field::ASSOC_CELL_SET, "cells", normal, nCells)); - dataSet.AddField(vtkm::cont::Field( - "c_chiSquare", vtkm::cont::Field::ASSOC_CELL_SET, "cells", chiSquare, nCells)); - dataSet.AddField( - vtkm::cont::Field("c_uniform", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells)); + dataSet.AddField(vtkm::cont::make_Field( + "c_poisson", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "c_normal", vtkm::cont::Field::ASSOC_CELL_SET, "cells", normal, nCells, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field("c_chiSquare", + vtkm::cont::Field::ASSOC_CELL_SET, + "cells", + chiSquare, + nCells, + vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "c_uniform", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells, vtkm::CopyFlag::On)); vtkm::cont::CellSetStructured cellSet("cells"); diff --git a/vtkm/filter/testing/UnitTestNDEntropyFilter.cxx b/vtkm/filter/testing/UnitTestNDEntropyFilter.cxx index 2d37fff32..395ad1bc1 100644 --- a/vtkm/filter/testing/UnitTestNDEntropyFilter.cxx +++ b/vtkm/filter/testing/UnitTestNDEntropyFilter.cxx @@ -173,9 +173,12 @@ vtkm::cont::DataSet MakeTestDataSet() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); // Set point scalars - dataSet.AddField(vtkm::cont::Field("fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( 
+ "fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts, vtkm::CopyFlag::On)); return dataSet; } diff --git a/vtkm/filter/testing/UnitTestNDHistogramFilter.cxx b/vtkm/filter/testing/UnitTestNDHistogramFilter.cxx index 01d452b95..2f2cbb111 100644 --- a/vtkm/filter/testing/UnitTestNDHistogramFilter.cxx +++ b/vtkm/filter/testing/UnitTestNDHistogramFilter.cxx @@ -56,9 +56,12 @@ vtkm::cont::DataSet MakeTestDataSet() }; // Set point scalars - dataSet.AddField(vtkm::cont::Field("fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts, vtkm::CopyFlag::On)); return dataSet; } diff --git a/vtkm/filter/testing/UnitTestPointElevationFilter.cxx b/vtkm/filter/testing/UnitTestPointElevationFilter.cxx index b3ed17151..b3bb50f1d 100644 --- a/vtkm/filter/testing/UnitTestPointElevationFilter.cxx +++ b/vtkm/filter/testing/UnitTestPointElevationFilter.cxx @@ -44,7 +44,8 @@ vtkm::cont::DataSet MakePointElevationTestDataSet() } vtkm::Id numCells = (dim - 1) * (dim - 1); - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, vtkm::CopyFlag::On)); vtkm::cont::CellSetExplicit<> cellSet("cells"); cellSet.PrepareToAddCells(numCells, numCells * 4); diff --git a/vtkm/internal/CMakeLists.txt b/vtkm/internal/CMakeLists.txt index 3c7be5426..99ed7c6b7 100755 --- a/vtkm/internal/CMakeLists.txt +++ 
b/vtkm/internal/CMakeLists.txt @@ -28,6 +28,7 @@ set(VTKM_USE_64BIT_IDS ${VTKm_USE_64BIT_IDS}) set(VTKM_ENABLE_CUDA ${VTKm_ENABLE_CUDA}) set(VTKM_ENABLE_TBB ${VTKm_ENABLE_TBB}) +set(VTKM_ENABLE_MPI ${VTKm_ENABLE_MPI}) vtkm_get_kit_name(kit_name kit_dir) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/Configure.h.in diff --git a/vtkm/internal/Configure.h.in b/vtkm/internal/Configure.h.in index 55a699f1a..95b166b63 100644 --- a/vtkm/internal/Configure.h.in +++ b/vtkm/internal/Configure.h.in @@ -263,6 +263,9 @@ #cmakedefine VTKM_ENABLE_TBB #endif +//Mark if we are building with MPI enabled. +#cmakedefine VTKM_ENABLE_MPI + #if __cplusplus >= 201103L || \ ( defined(VTKM_MSVC) && _MSC_VER >= 1800 ) || \ ( defined(VTKM_ICC) && defined(__INTEL_CXX11_MODE__) ) diff --git a/vtkm/internal/ListTagDetail.h b/vtkm/internal/ListTagDetail.h index 0d812f609..da6d28f6b 100644 --- a/vtkm/internal/ListTagDetail.h +++ b/vtkm/internal/ListTagDetail.h @@ -210,31 +210,47 @@ VTKM_CONT void ListForEachImpl(Functor&& f, std::forward(f), brigand::list{}, std::forward(args)...); } - -template -struct ListCrossProductAppend -{ - using type = brigand::push_back>; -}; - -template -struct ListCrossProductImplUnrollR2 -{ - using P = - brigand::fold, - ListCrossProductAppend>>; - - using type = brigand::append; -}; - template struct ListCrossProductImpl { - using type = brigand::fold< - R2, - brigand::list<>, - ListCrossProductImplUnrollR2>>; +#if defined(VTKM_MSVC) && _MSC_VER == 1800 + // This is a Cartesian product generator that is used + // when building with visual studio 2013. 
Visual Studio + // 2013 is unable to handle the lazy version as it can't + // deduce the correct template parameters + using type = brigand::reverse_fold< + brigand::list, + brigand::list>, + brigand::bind< + brigand::join, + brigand::bind< + brigand::transform, + brigand::_2, + brigand::defer, + brigand::defer>>>>>>>>>; +#else + // This is a lazy Cartesian product generator that is used + // when using any compiler other than visual studio 2013. + // This version was settled on as being the best default + // version as all compilers including Intel handle this + // implementation without issue for very large cross products + using type = brigand::reverse_fold< + brigand::list, + brigand::list>, + brigand::lazy::join, + brigand::defer>>>>>>>>>; +#endif }; diff --git a/vtkm/internal/brigand.hpp b/vtkm/internal/brigand.hpp index 04771b44b..f0769e643 100644 --- a/vtkm/internal/brigand.hpp +++ b/vtkm/internal/brigand.hpp @@ -235,6 +235,19 @@ namespace detail } namespace brigand { + template + struct same + { + using type = T; + }; + template + struct same + { + static_assert(std::is_same::value, ""); + using type = T; + }; + + namespace detail { template struct element_at; @@ -243,21 +256,20 @@ namespace brigand { template type_ static at(Ts..., type_*, ...); - //CUDA 9 version that is required - template type_ static at_with_type(Ts..., R, Other...); + //CUDA 9 and Intel 18 version that is required + template T static at_with_type(Ts..., T*, Other...); }; - template T extract_type(type_*); - template struct at_impl; -#if defined(BRIGAND_COMP_CUDA_9) - //Only needed for CUDA 9 RC1 as it has some compiler bugs +#if defined(BRIGAND_COMP_CUDA_9) || defined(BRIGAND_COMP_INTEL) + //Both CUDA 9 and the Intel 18 compiler series have a problem deducing the + //type so we are just going template class L, class... 
Ts> struct at_impl> { using base_with_type = decltype( element_at>::at_with_type(static_cast*>(nullptr)...)); - using type = decltype(extract_type(typename base_with_type::type{})); + using type = typename base_with_type::type; }; #else // This is the original implementation diff --git a/vtkm/io/writer/VTKDataSetWriter.h b/vtkm/io/writer/VTKDataSetWriter.h index 180731059..91ddc9586 100644 --- a/vtkm/io/writer/VTKDataSetWriter.h +++ b/vtkm/io/writer/VTKDataSetWriter.h @@ -193,8 +193,9 @@ private: vtkm::Id nids = cellSet.GetNumberOfPointsInCell(i); cellSet.GetIndices(i, ids); out << nids; + auto IdPortal = ids.GetPortalConstControl(); for (int j = 0; j < nids; ++j) - out << " " << ids.GetPortalControl().Get(j); + out << " " << IdPortal.Get(j); out << std::endl; } diff --git a/vtkm/rendering/MapperRayTracer.cxx b/vtkm/rendering/MapperRayTracer.cxx index 84e699bb2..fa0dae248 100644 --- a/vtkm/rendering/MapperRayTracer.cxx +++ b/vtkm/rendering/MapperRayTracer.cxx @@ -107,9 +107,9 @@ void MapperRayTracer::RenderCells(const vtkm::cont::DynamicCellSet& cellset, this->Internals->Rays, camera, *this->Internals->Canvas); vtkm::Bounds dataBounds = coords.GetBounds(); - + vtkm::cont::Field& field = const_cast(scalarField); this->Internals->Tracer.SetData( - coords.GetData(), indices, scalarField, numberOfTriangles, scalarRange, dataBounds); + coords.GetData(), indices, field, numberOfTriangles, scalarRange, dataBounds); this->Internals->Tracer.SetColorMap(this->ColorMap); this->Internals->Tracer.Render(this->Internals->Rays); diff --git a/vtkm/rendering/raytracing/RayTracer.cxx b/vtkm/rendering/raytracing/RayTracer.cxx index 6b06eecfb..cbf461880 100644 --- a/vtkm/rendering/raytracing/RayTracer.cxx +++ b/vtkm/rendering/raytracing/RayTracer.cxx @@ -234,14 +234,14 @@ public: VTKM_CONT void run(Ray& rays, LinearBVH& bvh, vtkm::cont::DynamicArrayHandleCoordinateSystem& coordsHandle, - const vtkm::cont::Field* scalarField, + vtkm::cont::Field& scalarField, const vtkm::Range& 
scalarRange) { - bool isSupportedField = (scalarField->GetAssociation() == vtkm::cont::Field::ASSOC_POINTS || - scalarField->GetAssociation() == vtkm::cont::Field::ASSOC_CELL_SET); + bool isSupportedField = (scalarField.GetAssociation() == vtkm::cont::Field::ASSOC_POINTS || + scalarField.GetAssociation() == vtkm::cont::Field::ASSOC_CELL_SET); if (!isSupportedField) throw vtkm::cont::ErrorBadValue("Field not accociated with cell set or points"); - bool isAssocPoints = scalarField->GetAssociation() == vtkm::cont::Field::ASSOC_POINTS; + bool isAssocPoints = scalarField.GetAssociation() == vtkm::cont::Field::ASSOC_POINTS; vtkm::worklet::DispatcherMapField(CalculateNormals(bvh.LeafNodes)) .Invoke(rays.HitIdx, rays.Dir, rays.NormalX, rays.NormalY, rays.NormalZ, coordsHandle); @@ -251,14 +251,14 @@ public: vtkm::worklet::DispatcherMapField, Device>( LerpScalar( bvh.LeafNodes, vtkm::Float32(scalarRange.Min), vtkm::Float32(scalarRange.Max))) - .Invoke(rays.HitIdx, rays.U, rays.V, rays.Scalar, *scalarField); + .Invoke(rays.HitIdx, rays.U, rays.V, rays.Scalar, scalarField); } else { vtkm::worklet::DispatcherMapField, Device>( NodalScalar( bvh.LeafNodes, vtkm::Float32(scalarRange.Min), vtkm::Float32(scalarRange.Max))) - .Invoke(rays.HitIdx, rays.Scalar, *scalarField); + .Invoke(rays.HitIdx, rays.Scalar, scalarField); } } // Run @@ -398,14 +398,14 @@ Camera& RayTracer::GetCamera() void RayTracer::SetData(const vtkm::cont::DynamicArrayHandleCoordinateSystem& coordsHandle, const vtkm::cont::ArrayHandle>& indices, - const vtkm::cont::Field& scalarField, + vtkm::cont::Field& scalarField, const vtkm::Id& numberOfTriangles, const vtkm::Range& scalarRange, const vtkm::Bounds& dataBounds) { CoordsHandle = coordsHandle; Indices = indices; - ScalarField = &scalarField; + ScalarField = scalarField; NumberOfTriangles = numberOfTriangles; ScalarRange = scalarRange; DataBounds = dataBounds; diff --git a/vtkm/rendering/raytracing/RayTracer.h b/vtkm/rendering/raytracing/RayTracer.h index 
7555dc394..8b1c20b74 100644 --- a/vtkm/rendering/raytracing/RayTracer.h +++ b/vtkm/rendering/raytracing/RayTracer.h @@ -34,13 +34,13 @@ namespace rendering namespace raytracing { -class RayTracer +class VTKM_RENDERING_EXPORT RayTracer { protected: LinearBVH Bvh; Camera camera; vtkm::cont::DynamicArrayHandleCoordinateSystem CoordsHandle; - const vtkm::cont::Field* ScalarField; + vtkm::cont::Field ScalarField; vtkm::cont::ArrayHandle> Indices; vtkm::cont::ArrayHandle Scalars; vtkm::Id NumberOfTriangles; @@ -63,7 +63,7 @@ public: VTKM_CONT void SetData(const vtkm::cont::DynamicArrayHandleCoordinateSystem& coordsHandle, const vtkm::cont::ArrayHandle>& indices, - const vtkm::cont::Field& scalarField, + vtkm::cont::Field& scalarField, const vtkm::Id& numberOfTriangles, const vtkm::Range& scalarRange, const vtkm::Bounds& dataBounds); diff --git a/vtkm/rendering/testing/UnitTestMapperWireframer.cxx b/vtkm/rendering/testing/UnitTestMapperWireframer.cxx index 5319f88a1..5ba8ef6b4 100644 --- a/vtkm/rendering/testing/UnitTestMapperWireframer.cxx +++ b/vtkm/rendering/testing/UnitTestMapperWireframer.cxx @@ -71,7 +71,8 @@ vtkm::cont::DataSet Make2DExplicitDataSet() pointVar.push_back(13); pointVar.push_back(14); pointVar.push_back(15); - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates, nVerts)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, nVerts, vtkm::CopyFlag::On)); vtkm::cont::CellSetSingleType<> cellSet("cells"); vtkm::cont::ArrayHandle connectivity; diff --git a/vtkm/testing/UnitTestListTag.cxx b/vtkm/testing/UnitTestListTag.cxx index 3e1a68d82..1971d4f58 100644 --- a/vtkm/testing/UnitTestListTag.cxx +++ b/vtkm/testing/UnitTestListTag.cxx @@ -68,7 +68,7 @@ struct TestListTagUniversal : vtkm::ListTagUniversal }; template -std::pair test_number(std::pair, TestClass>) +std::pair test_number(brigand::list, TestClass>) { return std::make_pair(N, M); } diff --git a/vtkm/worklet/CMakeLists.txt 
b/vtkm/worklet/CMakeLists.txt index fce818a58..4a70203bf 100644 --- a/vtkm/worklet/CMakeLists.txt +++ b/vtkm/worklet/CMakeLists.txt @@ -48,6 +48,7 @@ set(headers NDimsEntropy.h NDimsHistogram.h NDimsHistMarginalization.h + Normalize.h ParticleAdvection.h PointAverage.h PointElevation.h diff --git a/vtkm/worklet/Normalize.h b/vtkm/worklet/Normalize.h new file mode 100644 index 000000000..40ff5317f --- /dev/null +++ b/vtkm/worklet/Normalize.h @@ -0,0 +1,60 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2014 UT-Battelle, LLC. +// Copyright 2014 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. 
+//============================================================================ +#ifndef vtk_m_worklet_Normalize_h +#define vtk_m_worklet_Normalize_h + +#include + +#include + +namespace vtkm +{ +namespace worklet +{ + +class Normal : public vtkm::worklet::WorkletMapField +{ +public: + typedef void ControlSignature(FieldIn, FieldOut); + typedef void ExecutionSignature(_1, _2); + + template + VTKM_EXEC void operator()(const T& inValue, T2& outValue) const + { + outValue = vtkm::Normal(inValue); + } +}; + +class Normalize : public vtkm::worklet::WorkletMapField +{ +public: + typedef void ControlSignature(FieldInOut); + typedef void ExecutionSignature(_1); + + template + VTKM_EXEC void operator()(T& value) const + { + vtkm::Normalize(value); + } +}; +} +} // namespace vtkm::worklet + +#endif // vtk_m_worklet_Normalize_h diff --git a/vtkm/worklet/testing/CMakeLists.txt b/vtkm/worklet/testing/CMakeLists.txt index bdd64a3b9..034e14165 100644 --- a/vtkm/worklet/testing/CMakeLists.txt +++ b/vtkm/worklet/testing/CMakeLists.txt @@ -38,6 +38,7 @@ set(unit_tests UnitTestMarchingCubes.cxx UnitTestMask.cxx UnitTestMaskPoints.cxx + UnitTestNormalize.cxx UnitTestNDimsEntropy.cxx UnitTestNDimsHistogram.cxx UnitTestNDimsHistMarginalization.cxx diff --git a/vtkm/worklet/testing/UnitTestFieldHistogram.cxx b/vtkm/worklet/testing/UnitTestFieldHistogram.cxx index 6c751ea0c..775c33aa5 100644 --- a/vtkm/worklet/testing/UnitTestFieldHistogram.cxx +++ b/vtkm/worklet/testing/UnitTestFieldHistogram.cxx @@ -227,23 +227,28 @@ vtkm::cont::DataSet MakeTestDataSet() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); // Set point scalars - dataSet.AddField( - vtkm::cont::Field("p_poisson", vtkm::cont::Field::ASSOC_POINTS, poisson, nVerts)); - dataSet.AddField(vtkm::cont::Field("p_normal", vtkm::cont::Field::ASSOC_POINTS, normal, nVerts)); - dataSet.AddField( - vtkm::cont::Field("p_chiSquare", vtkm::cont::Field::ASSOC_POINTS, chiSquare, nVerts)); - 
dataSet.AddField( - vtkm::cont::Field("p_uniform", vtkm::cont::Field::ASSOC_POINTS, uniform, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "p_poisson", vtkm::cont::Field::ASSOC_POINTS, poisson, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_normal", vtkm::cont::Field::ASSOC_POINTS, normal, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_chiSquare", vtkm::cont::Field::ASSOC_POINTS, chiSquare, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_uniform", vtkm::cont::Field::ASSOC_POINTS, uniform, nVerts, vtkm::CopyFlag::On)); // Set cell scalars - dataSet.AddField( - vtkm::cont::Field("c_poisson", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells)); - dataSet.AddField( - vtkm::cont::Field("c_normal", vtkm::cont::Field::ASSOC_CELL_SET, "cells", normal, nCells)); - dataSet.AddField(vtkm::cont::Field( - "c_chiSquare", vtkm::cont::Field::ASSOC_CELL_SET, "cells", chiSquare, nCells)); - dataSet.AddField( - vtkm::cont::Field("c_uniform", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells)); + dataSet.AddField(vtkm::cont::make_Field( + "c_poisson", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "c_normal", vtkm::cont::Field::ASSOC_CELL_SET, "cells", normal, nCells, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field("c_chiSquare", + vtkm::cont::Field::ASSOC_CELL_SET, + "cells", + chiSquare, + nCells, + vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "c_uniform", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells, vtkm::CopyFlag::On)); vtkm::cont::CellSetStructured cellSet("cells"); diff --git a/vtkm/worklet/testing/UnitTestFieldStatistics.cxx b/vtkm/worklet/testing/UnitTestFieldStatistics.cxx index a0cc6d960..74792bf21 100644 --- a/vtkm/worklet/testing/UnitTestFieldStatistics.cxx +++ b/vtkm/worklet/testing/UnitTestFieldStatistics.cxx @@ -46,8 
+46,8 @@ vtkm::cont::DataSet Make2DUniformStatDataSet0() // Create cell scalar vtkm::Float32 data[nVerts] = { 4, 1, 10, 6, 8, 2, 9, 3, 5, 7 }; - dataSet.AddField( - vtkm::cont::Field("data", vtkm::cont::Field::ASSOC_CELL_SET, "cells", data, nCells)); + dataSet.AddField(vtkm::cont::make_Field( + "data", vtkm::cont::Field::ASSOC_CELL_SET, "cells", data, nCells, vtkm::CopyFlag::On)); vtkm::cont::CellSetStructured cellSet("cells"); @@ -262,23 +262,28 @@ vtkm::cont::DataSet Make2DUniformStatDataSet1() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); // Set point scalars - dataSet.AddField( - vtkm::cont::Field("p_poisson", vtkm::cont::Field::ASSOC_POINTS, poisson, nVerts)); - dataSet.AddField(vtkm::cont::Field("p_normal", vtkm::cont::Field::ASSOC_POINTS, normal, nVerts)); - dataSet.AddField( - vtkm::cont::Field("p_chiSquare", vtkm::cont::Field::ASSOC_POINTS, chiSquare, nVerts)); - dataSet.AddField( - vtkm::cont::Field("p_uniform", vtkm::cont::Field::ASSOC_POINTS, uniform, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "p_poisson", vtkm::cont::Field::ASSOC_POINTS, poisson, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_normal", vtkm::cont::Field::ASSOC_POINTS, normal, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_chiSquare", vtkm::cont::Field::ASSOC_POINTS, chiSquare, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "p_uniform", vtkm::cont::Field::ASSOC_POINTS, uniform, nVerts, vtkm::CopyFlag::On)); // Set cell scalars - dataSet.AddField( - vtkm::cont::Field("c_poisson", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells)); - dataSet.AddField( - vtkm::cont::Field("c_normal", vtkm::cont::Field::ASSOC_CELL_SET, "cells", normal, nCells)); - dataSet.AddField(vtkm::cont::Field( - "c_chiSquare", vtkm::cont::Field::ASSOC_CELL_SET, "cells", chiSquare, nCells)); - dataSet.AddField( - vtkm::cont::Field("c_uniform", 
vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells)); + dataSet.AddField(vtkm::cont::make_Field( + "c_poisson", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "c_normal", vtkm::cont::Field::ASSOC_CELL_SET, "cells", normal, nCells, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field("c_chiSquare", + vtkm::cont::Field::ASSOC_CELL_SET, + "cells", + chiSquare, + nCells, + vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "c_uniform", vtkm::cont::Field::ASSOC_CELL_SET, "cells", poisson, nCells, vtkm::CopyFlag::On)); vtkm::cont::CellSetStructured cellSet("cells"); diff --git a/vtkm/worklet/testing/UnitTestNDimsEntropy.cxx b/vtkm/worklet/testing/UnitTestNDimsEntropy.cxx index 1b432b21a..b30ede0c6 100644 --- a/vtkm/worklet/testing/UnitTestNDimsEntropy.cxx +++ b/vtkm/worklet/testing/UnitTestNDimsEntropy.cxx @@ -173,9 +173,12 @@ vtkm::cont::DataSet MakeTestDataSet() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); // Set point scalars - dataSet.AddField(vtkm::cont::Field("fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts, vtkm::CopyFlag::On)); return dataSet; } diff --git a/vtkm/worklet/testing/UnitTestNDimsHistMarginalization.cxx b/vtkm/worklet/testing/UnitTestNDimsHistMarginalization.cxx index 218e05068..3417a4425 100644 --- 
a/vtkm/worklet/testing/UnitTestNDimsHistMarginalization.cxx +++ b/vtkm/worklet/testing/UnitTestNDimsHistMarginalization.cxx @@ -175,9 +175,12 @@ vtkm::cont::DataSet MakeTestDataSet() dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); // Set point scalars - dataSet.AddField(vtkm::cont::Field("fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts, vtkm::CopyFlag::On)); return dataSet; } diff --git a/vtkm/worklet/testing/UnitTestNDimsHistogram.cxx b/vtkm/worklet/testing/UnitTestNDimsHistogram.cxx index 24b188aef..39329afc7 100644 --- a/vtkm/worklet/testing/UnitTestNDimsHistogram.cxx +++ b/vtkm/worklet/testing/UnitTestNDimsHistogram.cxx @@ -56,9 +56,12 @@ vtkm::cont::DataSet MakeTestDataSet() }; // Set point scalars - dataSet.AddField(vtkm::cont::Field("fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts)); - dataSet.AddField(vtkm::cont::Field("fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldA", vtkm::cont::Field::ASSOC_POINTS, fieldA, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldB", vtkm::cont::Field::ASSOC_POINTS, fieldB, nVerts, vtkm::CopyFlag::On)); + dataSet.AddField(vtkm::cont::make_Field( + "fieldC", vtkm::cont::Field::ASSOC_POINTS, fieldC, nVerts, vtkm::CopyFlag::On)); return 
dataSet; } diff --git a/vtkm/worklet/testing/UnitTestNormalize.cxx b/vtkm/worklet/testing/UnitTestNormalize.cxx new file mode 100644 index 000000000..861b56b33 --- /dev/null +++ b/vtkm/worklet/testing/UnitTestNormalize.cxx @@ -0,0 +1,145 @@ +//============================================================================ +// Copyright (c) Kitware, Inc. +// All rights reserved. +// See LICENSE.txt for details. +// This software is distributed WITHOUT ANY WARRANTY; without even +// the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +// PURPOSE. See the above copyright notice for more information. +// +// Copyright 2014 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +// Copyright 2014 UT-Battelle, LLC. +// Copyright 2014 Los Alamos National Security. +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Under the terms of Contract DE-AC52-06NA25396 with Los Alamos National +// Laboratory (LANL), the U.S. Government retains certain rights in +// this software. 
+//============================================================================ + +#include +#include + +#include + +namespace +{ + +template +void createVectors(std::vector>& vecs) +{ + vecs.push_back(vtkm::make_Vec(2, 0, 0)); + vecs.push_back(vtkm::make_Vec(0, 2, 0)); + vecs.push_back(vtkm::make_Vec(0, 0, 2)); + vecs.push_back(vtkm::make_Vec(1, 1, 1)); + vecs.push_back(vtkm::make_Vec(2, 2, 2)); + vecs.push_back(vtkm::make_Vec(2, 1, 1)); + + vecs.push_back(vtkm::make_Vec(1000000, 0, 0)); + + vecs.push_back(vtkm::make_Vec(static_cast(.1), static_cast(0), static_cast(0))); + vecs.push_back(vtkm::make_Vec(static_cast(.001), static_cast(0), static_cast(0))); +} + +template +void createVectors(std::vector>& vecs) +{ + vecs.push_back(vtkm::make_Vec(1, 0)); + vecs.push_back(vtkm::make_Vec(0, 1)); + vecs.push_back(vtkm::make_Vec(1, 1)); + vecs.push_back(vtkm::make_Vec(2, 0)); + vecs.push_back(vtkm::make_Vec(0, 2)); + vecs.push_back(vtkm::make_Vec(2, 2)); + + vecs.push_back(vtkm::make_Vec(1000000, 0)); + + vecs.push_back(vtkm::make_Vec(static_cast(.1), static_cast(0))); + vecs.push_back(vtkm::make_Vec(static_cast(.001), static_cast(0))); +} + +template +void TestNormal() +{ + std::vector> inputVecs; + createVectors(inputVecs); + + vtkm::cont::ArrayHandle> inputArray; + vtkm::cont::ArrayHandle> outputArray; + inputArray = vtkm::cont::make_ArrayHandle(inputVecs); + + vtkm::worklet::Normal normalWorklet; + vtkm::worklet::DispatcherMapField dispatcherNormal(normalWorklet); + dispatcherNormal.Invoke(inputArray, outputArray); + + //Validate results. + + //Make sure the number of values match. + VTKM_TEST_ASSERT(outputArray.GetNumberOfValues() == inputArray.GetNumberOfValues(), + "Wrong number of results for Normalize worklet"); + + //Make sure each vector is correct. + for (vtkm::Id i = 0; i < inputArray.GetNumberOfValues(); i++) + { + //Make sure that the value is correct. 
+ vtkm::Vec v = inputArray.GetPortalConstControl().Get(i); + vtkm::Vec vN = outputArray.GetPortalConstControl().Get(i); + T len = vtkm::Magnitude(v); + VTKM_TEST_ASSERT(test_equal(v / len, vN), "Wrong result for Normalize worklet"); + + //Make sure the magnitudes are all 1.0 + len = vtkm::Magnitude(vN); + VTKM_TEST_ASSERT(test_equal(len, 1), "Wrong magnitude for Normalize worklet"); + } +} + +template +void TestNormalize() +{ + std::vector> inputVecs; + createVectors(inputVecs); + + vtkm::cont::ArrayHandle> inputArray; + vtkm::cont::ArrayHandle> outputArray; + inputArray = vtkm::cont::make_ArrayHandle(inputVecs); + + vtkm::worklet::Normalize normalizeWorklet; + vtkm::worklet::DispatcherMapField dispatcherNormalize(normalizeWorklet); + dispatcherNormalize.Invoke(inputArray); + + //Make sure each vector is correct. + for (vtkm::Id i = 0; i < inputArray.GetNumberOfValues(); i++) + { + //Make sure that the value is correct. + vtkm::Vec v = inputVecs[static_cast(i)]; + vtkm::Vec vN = inputArray.GetPortalConstControl().Get(i); + T len = vtkm::Magnitude(v); + VTKM_TEST_ASSERT(test_equal(v / len, vN), "Wrong result for Normalize worklet"); + + //Make sure the magnitudes are all 1.0 + len = vtkm::Magnitude(vN); + VTKM_TEST_ASSERT(test_equal(len, 1), "Wrong magnitude for Normalize worklet"); + } +} + +void TestNormalWorklets() +{ + std::cout << "Testing Normal Worklet" << std::endl; + + TestNormal(); + TestNormal(); + TestNormal(); + TestNormal(); + + std::cout << "Testing Normalize Worklet" << std::endl; + TestNormalize(); + TestNormalize(); + TestNormalize(); + TestNormalize(); +} +} + +int UnitTestNormalize(int, char* []) +{ + return vtkm::cont::testing::Testing::Run(TestNormalWorklets); +} diff --git a/vtkm/worklet/testing/UnitTestPointElevation.cxx b/vtkm/worklet/testing/UnitTestPointElevation.cxx index ad46234e1..965f9d75d 100644 --- a/vtkm/worklet/testing/UnitTestPointElevation.cxx +++ b/vtkm/worklet/testing/UnitTestPointElevation.cxx @@ -49,7 +49,8 @@ 
vtkm::cont::DataSet MakePointElevationTestDataSet() } vtkm::Id numCells = (dim - 1) * (dim - 1); - dataSet.AddCoordinateSystem(vtkm::cont::CoordinateSystem("coordinates", coordinates)); + dataSet.AddCoordinateSystem( + vtkm::cont::make_CoordinateSystem("coordinates", coordinates, vtkm::CopyFlag::On)); vtkm::cont::CellSetExplicit<> cellSet("cells"); cellSet.PrepareToAddCells(numCells, numCells * 4); diff --git a/vtkm/worklet/testing/UnitTestPointGradient.cxx b/vtkm/worklet/testing/UnitTestPointGradient.cxx index 2c6827297..bca45e3cc 100644 --- a/vtkm/worklet/testing/UnitTestPointGradient.cxx +++ b/vtkm/worklet/testing/UnitTestPointGradient.cxx @@ -226,11 +226,11 @@ void TestPointGradientExplicit() void TestPointGradient() { using DeviceAdapter = VTKM_DEFAULT_DEVICE_ADAPTER_TAG; - // TestPointGradientUniform2D(); + TestPointGradientUniform2D(); TestPointGradientUniform3D(); - // TestPointGradientUniform3DWithVectorField(); - // TestPointGradientUniform3DWithVectorField2(); - // TestPointGradientExplicit(); + TestPointGradientUniform3DWithVectorField(); + TestPointGradientUniform3DWithVectorField2(); + TestPointGradientExplicit(); } }