diff --git a/.github/ISSUE_TEMPLATE/---document-issue-.md b/.github/ISSUE_TEMPLATE/---document-issue-.md
index 7c464ac584bc87cb16a796bf41acdcd79b8bd6f0..ffc2fcd7817b64584637a646edf5907612a7bbaf 100644
--- a/.github/ISSUE_TEMPLATE/---document-issue-.md
+++ b/.github/ISSUE_TEMPLATE/---document-issue-.md
@@ -56,4 +56,4 @@ For example: no sample code; The sample code is not helpful; The sample code not
For example:Chinese API in this doc is inconsistent with English API, including params, description, sample code, formula, etc.
#### Other
-For example: The doc link is broken; The doc page is missing; Dead link in docs.
\ No newline at end of file
+For example: The doc link is broken; The doc page is missing; Dead link in docs.
diff --git a/CMakeLists.txt b/CMakeLists.txt
index f24513d605c49b608cb32425a861448a3acd6c6a..f30671bd3a87e87732b3a047e91811452370e06e 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,6 +13,7 @@
# limitations under the License
cmake_minimum_required(VERSION 3.10)
+cmake_policy(VERSION 3.10)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
@@ -21,9 +22,6 @@ include(system)
project(paddle CXX C)
-include(init)
-include(generic) # simplify cmake module
-
# enable language CUDA
# TODO(Shibo Tao): remove find_package(CUDA) completely.
find_package(CUDA QUIET)
@@ -32,16 +30,23 @@ option(WITH_TENSORRT "Compile PaddlePaddle with NVIDIA TensorRT" OFF)
option(WITH_XPU "Compile PaddlePaddle with BAIDU KUNLUN XPU" OFF)
option(WITH_WIN_DUMP_DBG "Compile with windows core dump debug mode" OFF)
option(WITH_ASCEND "Compile PaddlePaddle with ASCEND" OFF)
+option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)
+# NOTE(zhiqiu): WITH_ASCEND_CL can be compiled on x86_64, so we can set WITH_ASCEND=OFF and WITH_ASCEND_CL=ON
+# to develop ACL-related functionality on x86
+option(WITH_ASCEND_CL "Compile PaddlePaddle with ASCEND CL" ${WITH_ASCEND})
+option(WITH_ASCEND_CXX11 "Compile PaddlePaddle with ASCEND and CXX11 ABI" OFF)
+# Note(zhouwei): these includes use the options above, so they must come after them
+include(init)
+include(generic) # simplify cmake module
+
if (WITH_GPU AND WITH_XPU)
message(FATAL_ERROR "Error when compile GPU and XPU at the same time")
endif()
-if (WITH_GPU AND WITH_ASCEND)
+if (WITH_GPU AND WITH_ASCEND)
message(FATAL_ERROR "Error when compile GPU and ASCEND at the same time")
endif()
-# cmake 3.12, 3.13, 3.14 will append gcc link options to nvcc, and nvcc doesn't recognize them.
-if(WITH_GPU AND (${CMAKE_VERSION} VERSION_GREATER_EQUAL 3.12) AND (${CMAKE_VERSION} VERSION_LESS 3.15))
- message(FATAL_ERROR "cmake ${CMAKE_VERSION} is not supported when WITH_GPU=ON because of bug https://cmake.org/pipermail/cmake/2018-September/068195.html. "
- "You can use cmake 3.16 (recommended), 3.10, 3.11, 3.15 or 3.17. Please refer to the install document: https://cmake.org/install/")
+if (WITH_GPU AND WITH_ROCM)
+ message(FATAL_ERROR "Error when compile CUDA and ROCM at the same time")
endif()
if(WITH_GPU AND NOT APPLE)
@@ -61,6 +66,10 @@ if(WITH_MUSL)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-error=deprecated-declarations -Wno-deprecated-declarations -Wno-error=pessimizing-move -Wno-error=deprecated-copy")
endif()
+if(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D_GLIBCXX_USE_CXX11_ABI=0")
+endif()
+
if(WIN32)
option(MSVC_STATIC_CRT "use static C Runtime library by default" ON)
@@ -72,6 +81,13 @@ if(WIN32)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /bigobj")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /bigobj")
+ if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
+ set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /Zc:inline")
+ set(CMAKE_C_FLAGS_RELEASE "${CMAKE_C_FLAGS_RELEASE} /Zc:inline")
+ set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} /Zc:inline")
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /Zc:inline")
+ endif()
+
if (MSVC_STATIC_CRT)
message(STATUS "Use static C runtime time, refer to https://docs.microsoft.com/en-us/cpp/c-runtime-library/crt-library-features?view=vs-2019")
set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} /MTd")
@@ -89,8 +105,8 @@ if(WIN32)
endforeach(flag_var)
endif()
- # NOTE(Avin0323): Less parallel count result in faster compilation.
math(EXPR PROCESS_MAX "${CPU_CORES} * 2 / 3")
+
# windows build turn off warnings, use parallel compiling.
foreach(flag_var
CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
@@ -98,7 +114,10 @@ if(WIN32)
CMAKE_C_FLAGS CMAKE_C_FLAGS_DEBUG CMAKE_C_FLAGS_RELEASE
CMAKE_C_FLAGS_MINSIZEREL CMAKE_C_FLAGS_RELWITHDEBINFO)
string(REGEX REPLACE "/W[1-4]" " /W0 " ${flag_var} "${${flag_var}}")
- set(${flag_var} "${${flag_var}} /MP${PROCESS_MAX}")
+      # NOTE(zhouwei25): GPU compilation consumes too much memory when compiling in parallel
+ if(NOT WITH_GPU)
+ set(${flag_var} "${${flag_var}} /MP${PROCESS_MAX}")
+ endif()
endforeach(flag_var)
foreach(flag_var CMAKE_CXX_FLAGS CMAKE_C_FLAGS)
set(${flag_var} "${${flag_var}} /w")
@@ -116,6 +135,13 @@ if(WIN32)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /wd4068 /wd4129 /wd4244 /wd4267 /wd4297 /wd4530 /wd4577 /wd4819 /wd4838")
+ foreach(flag_var CMAKE_SHARED_LINKER_FLAGS CMAKE_STATIC_LINKER_FLAGS CMAKE_EXE_LINKER_FLAGS CMAKE_LINKER_FLAGS)
+ set(${flag_var} "${${flag_var}} /ignore:4049 /ignore:4217 /ignore:4006 /ignore:4221")
+ if(MSVC_STATIC_CRT)
+ set(${flag_var} "${${flag_var}} /NODEFAULTLIB:MSVCRT.LIB")
+ endif()
+ endforeach(flag_var)
+
if (WITH_WIN_DUMP_DBG)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /Zi")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zi")
@@ -153,8 +179,6 @@ option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_BRPC_RDMA "Use brpc rdma as the rpc protocal" OFF)
option(ON_INFER "Turn on inference optimization and inference-lib generation" OFF)
################################ Internal Configurations #######################################
-option(WITH_ROCM "Compile PaddlePaddle with ROCM platform" OFF)
-option(WITH_RCCL "Compile PaddlePaddle with RCCL support" OFF)
option(WITH_NV_JETSON "Compile PaddlePaddle with NV JETSON" OFF)
option(WITH_PROFILER "Compile PaddlePaddle with GPU profiler and gperftools" OFF)
option(WITH_COVERAGE "Compile PaddlePaddle with code coverage" OFF)
@@ -165,14 +189,15 @@ option(WITH_PSLIB "Compile with pslib support" OFF)
option(WITH_BOX_PS "Compile with box_ps support" OFF)
option(WITH_XBYAK "Compile with xbyak support" ON)
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
-option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_PSCORE "Compile with parameter server support" ${WITH_DISTRIBUTE})
+option(WITH_HETERPS    "Compile with HeterPS" OFF)
option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE})
option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF)
option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF)
option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
+option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON)
option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF)
option(WITH_CRYPTO "Compile PaddlePaddle with crypto support" ON)
option(WITH_ARM "Compile PaddlePaddle with arm support" OFF)
@@ -180,6 +205,7 @@ option(WITH_SW "Compile PaddlePaddle with sw support" OFF)
option(WITH_MIPS "Compile PaddlePaddle with mips support" OFF)
option(WITH_MUSL "Compile with musl libc instead of gblic" OFF)
option(WITH_UNITY_BUILD "Compile with UnityBuild mode" OFF)
+option(WITH_STRIP "Strip the .so files in wheel packages" OFF)
# PY_VERSION
if(NOT PY_VERSION)
@@ -240,9 +266,6 @@ endif()
if(WITH_BRPC_RDMA)
message(STATUS "Use brpc with rdma.")
- if(WITH_GRPC)
- message(FATAL_ERROR "Can't use grpc with brpc rdma.")
- endif()
if(NOT WITH_DISTRIBUTE)
message(FATAL_ERROR "Can't use brpc rdma in no distribute env.")
endif()
@@ -290,9 +313,9 @@ endif(WITH_ROCM)
if (NOT WITH_ROCM AND WITH_RCCL)
MESSAGE(WARNING
- "Disable RCCL when compiling without GPU. Force WITH_RCCL=OFF.")
- set(WITH_NCCL OFF CACHE STRING
- "Disable RCCL when compiling without GPU" FORCE)
+ "Disable RCCL when compiling without ROCM. Force WITH_RCCL=OFF.")
+ set(WITH_RCCL OFF CACHE STRING
+ "Disable RCCL when compiling without ROCM" FORCE)
endif()
if(WITH_RCCL)
@@ -330,6 +353,11 @@ if (WITH_MIPS)
add_definitions(-DPADDLE_WITH_MIPS)
endif()
+if (WITH_HETERPS)
+ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -faligned-new")
+ endif()
+endif()
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
@@ -347,6 +375,13 @@ else()
message(WARNING "On inference mode, will take place some specific optimization. Turn on the ON_INFER flag when building inference_lib only.")
endif()
+if(WITH_STRIP)
+ find_program(STRIP_PATH strip)
+ if(NOT STRIP_PATH OR NOT LINUX)
+ set(WITH_STRIP OFF CACHE STRING "Command strip is only used on Linux when it exists." FORCE)
+ endif()
+endif()
+
add_subdirectory(paddle)
if(WITH_PYTHON)
add_subdirectory(python)
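WITH_STRIP is only validated at configure time above; the actual stripping happens later in the build. A minimal sketch of a downstream consumer, assuming a hypothetical target name (nothing below is part of this diff):

```cmake
# Hypothetical post-build step driven by WITH_STRIP/STRIP_PATH; the target
# name "paddle_pybind" is a placeholder.
if(WITH_STRIP)
  add_custom_command(TARGET paddle_pybind POST_BUILD
    COMMAND ${STRIP_PATH} $<TARGET_FILE:paddle_pybind>
    COMMENT "Stripping symbols to shrink the .so shipped in the wheel")
endif()
```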
diff --git a/README.md b/README.md
index e8a7013d0b4432bc871843b83cf19494ca870cbc..8b437e4115abe80073866f52f3d7e387e2a554d3 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,4 @@
-
-
+
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 9c1bd52e7fb7dfad5f6dc36d850468bf69ee92cd..e7f125269be1f5e015c6cf015489c312538ca4ba 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -82,6 +82,10 @@ if(WITH_ASCEND)
add_definitions(-DPADDLE_WITH_ASCEND)
endif()
+if(WITH_ASCEND_CL)
+ add_definitions(-DPADDLE_WITH_ASCEND_CL)
+endif()
+
if(WITH_XPU)
message(STATUS "Compile with XPU!")
add_definitions(-DPADDLE_WITH_XPU)
@@ -93,13 +97,18 @@ if(WITH_GPU)
FIND_PACKAGE(CUDA REQUIRED)
- if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 7)
- message(FATAL_ERROR "Paddle needs CUDA >= 7.0 to compile")
+ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_LESS 10.1)
+ message(FATAL_ERROR "Paddle needs CUDA >= 10.1 to compile")
endif()
if(NOT CUDNN_FOUND)
message(FATAL_ERROR "Paddle needs cudnn to compile")
endif()
+
+ if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
+ message(FATAL_ERROR "Paddle needs CUDNN >= 7.0 to compile")
+ endif()
+
if(CUPTI_FOUND)
include_directories(${CUPTI_INCLUDE_DIR})
add_definitions(-DPADDLE_WITH_CUPTI)
@@ -164,10 +173,9 @@ if(WITH_PSCORE)
add_definitions(-DPADDLE_WITH_PSCORE)
endif()
-
-if(WITH_GRPC)
- add_definitions(-DPADDLE_WITH_GRPC)
-endif(WITH_GRPC)
+if(WITH_HETERPS)
+ add_definitions(-DPADDLE_WITH_HETERPS)
+endif()
if(WITH_BRPC_RDMA)
add_definitions(-DPADDLE_WITH_BRPC_RDMA)
diff --git a/cmake/cuda.cmake b/cmake/cuda.cmake
index 2f4f5449f482d71a2a27957af4b5f17601ab634f..7f2addb02d36ddf85cd08542cc5baab31d495bc5 100644
--- a/cmake/cuda.cmake
+++ b/cmake/cuda.cmake
@@ -6,15 +6,9 @@ endif()
if (WITH_NV_JETSON)
add_definitions(-DWITH_NV_JETSON)
set(paddle_known_gpu_archs "53 62 72")
- set(paddle_known_gpu_archs7 "53")
- set(paddle_known_gpu_archs8 "53 62")
- set(paddle_known_gpu_archs9 "53 62")
set(paddle_known_gpu_archs10 "53 62 72")
else()
- set(paddle_known_gpu_archs "30 35 50 52 60 61 70")
- set(paddle_known_gpu_archs7 "30 35 50 52")
- set(paddle_known_gpu_archs8 "30 35 50 52 60 61")
- set(paddle_known_gpu_archs9 "30 35 50 52 60 61 70")
+ set(paddle_known_gpu_archs "35 50 52 60 61 70 75 80")
set(paddle_known_gpu_archs10 "35 50 52 60 61 70 75")
set(paddle_known_gpu_archs11 "52 60 61 70 75 80")
endif()
@@ -74,7 +68,7 @@ endfunction()
# select_nvcc_arch_flags(out_variable)
function(select_nvcc_arch_flags out_variable)
# List of arch names
- set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "All" "Manual")
+ set(archs_names "Kepler" "Maxwell" "Pascal" "Volta" "Turing" "Ampere" "All" "Manual")
set(archs_name_default "Auto")
list(APPEND archs_names "Auto")
@@ -91,7 +85,7 @@ function(select_nvcc_arch_flags out_variable)
if(${CUDA_ARCH_NAME} STREQUAL "Manual")
set(CUDA_ARCH_BIN ${paddle_known_gpu_archs} CACHE STRING "Specify 'real' GPU architectures to build binaries for, BIN(PTX) format is supported")
- set(CUDA_ARCH_PTX "50" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
+ set(CUDA_ARCH_PTX "" CACHE STRING "Specify 'virtual' PTX architectures to build PTX intermediate code for")
mark_as_advanced(CUDA_ARCH_BIN CUDA_ARCH_PTX)
else()
unset(CUDA_ARCH_BIN CACHE)
@@ -108,6 +102,8 @@ function(select_nvcc_arch_flags out_variable)
set(cuda_arch_bin "70")
elseif(${CUDA_ARCH_NAME} STREQUAL "Turing")
set(cuda_arch_bin "75")
+ elseif(${CUDA_ARCH_NAME} STREQUAL "Ampere")
+ set(cuda_arch_bin "80")
elseif(${CUDA_ARCH_NAME} STREQUAL "All")
set(cuda_arch_bin ${paddle_known_gpu_archs})
elseif(${CUDA_ARCH_NAME} STREQUAL "Auto")
@@ -158,31 +154,21 @@ function(select_nvcc_arch_flags out_variable)
endfunction()
message(STATUS "CUDA detected: " ${CMAKE_CUDA_COMPILER_VERSION})
-if (${CMAKE_CUDA_COMPILER_VERSION} LESS 7.0)
- set(paddle_known_gpu_archs ${paddle_known_gpu_archs})
-elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 8.0) # CUDA 7.x
- set(paddle_known_gpu_archs ${paddle_known_gpu_archs7})
- set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
- set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
-elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 9.0) # CUDA 8.x
- set(paddle_known_gpu_archs ${paddle_known_gpu_archs8})
+if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) # CUDA 10.x
+ set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
- # CUDA 8 may complain that sm_20 is no longer supported. Suppress the
- # warning for now.
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
-elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0) # CUDA 9.x
- set(paddle_known_gpu_archs ${paddle_known_gpu_archs9})
- set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
- set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
-elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) # CUDA 10.x
- set(paddle_known_gpu_archs ${paddle_known_gpu_archs10})
+elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.2) # CUDA 11.0/11.1
+ set(paddle_known_gpu_archs ${paddle_known_gpu_archs11})
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
-elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.x
- set(paddle_known_gpu_archs ${paddle_known_gpu_archs11})
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
+elseif (${CMAKE_CUDA_COMPILER_VERSION} LESS 12.0) # CUDA 11.2+
+ set(paddle_known_gpu_archs "${paddle_known_gpu_archs11} 86")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D_MWAITXINTRIN_H_INCLUDED")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -D__STRICT_ANSI__")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Wno-deprecated-gpu-targets")
endif()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} LESS 10.0)
@@ -198,14 +184,11 @@ select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
message(STATUS "NVCC_FLAGS_EXTRA: ${NVCC_FLAGS_EXTRA}")
-# Set C++11 support
+# Set C++14 support
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here.
-if (NOT WIN32) # windows msvc2015 support c++11 natively.
- # -std=c++11 -fPIC not recoginize by msvc, -Xcompiler will be added by cmake.
- set(CMAKE_CUDA_STANDARD 11)
-endif(NOT WIN32)
+set(CMAKE_CUDA_STANDARD 14)
# (Note) For windows, if delete /W[1-4], /W1 will be added defaultly and conflic with -w
# So replace /W[1-4] with /W0
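For orientation, when CUDA_ARCH_NAME=Ampere is selected, the arch list above reduces to "80"; a sketch of the resulting flags, assuming the usual -gencode spelling that select_nvcc_arch_flags emits:

```cmake
# Illustrative only: cuda_arch_bin "80" typically expands to a flag pair like
# this, which reaches nvcc through NVCC_FLAGS_EXTRA and CMAKE_CUDA_FLAGS.
set(NVCC_FLAGS_EXTRA "-gencode arch=compute_80,code=sm_80")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} ${NVCC_FLAGS_EXTRA}")
```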
diff --git a/cmake/cudnn.cmake b/cmake/cudnn.cmake
index d8d8f634e76b6bf05d4936921ce37c889a4bdc7c..c82847100abefa6fcbaf1367965699413aadcadb 100644
--- a/cmake/cudnn.cmake
+++ b/cmake/cudnn.cmake
@@ -94,7 +94,7 @@ macro(find_cudnn_version cudnn_header_file)
"${CUDNN_MAJOR_VERSION} * 1000 +
${CUDNN_MINOR_VERSION} * 100 + ${CUDNN_PATCHLEVEL_VERSION}")
message(STATUS "Current cuDNN header is ${cudnn_header_file} "
- "Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}. ")
+ "Current cuDNN version is v${CUDNN_MAJOR_VERSION}.${CUDNN_MINOR_VERSION}.${CUDNN_PATCHLEVEL_VERSION}. ")
endif()
endif()
endmacro()
diff --git a/cmake/external/ascend.cmake b/cmake/external/ascend.cmake
index bcf0c0a0646fc386f41c4b1f35ba773d6a1adb6f..414b2a54be0342b3ef76d5e3a553577cb5f3e4be 100644
--- a/cmake/external/ascend.cmake
+++ b/cmake/external/ascend.cmake
@@ -12,50 +12,78 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-INCLUDE(ExternalProject)
-
-SET(ASCEND_PROJECT "extern_ascend")
-IF((NOT DEFINED ASCEND_VER) OR (NOT DEFINED ASCEND_URL))
- MESSAGE(STATUS "use pre defined download url")
- SET(ASCEND_VER "0.1.1" CACHE STRING "" FORCE)
- SET(ASCEND_NAME "ascend" CACHE STRING "" FORCE)
- SET(ASCEND_URL "http://paddle-ascend.bj.bcebos.com/ascend.tar.gz" CACHE STRING "" FORCE)
-ENDIF()
-MESSAGE(STATUS "ASCEND_NAME: ${ASCEND_NAME}, ASCEND_URL: ${ASCEND_URL}")
-SET(ASCEND_SOURCE_DIR "${THIRD_PARTY_PATH}/ascend")
-SET(ASCEND_DOWNLOAD_DIR "${ASCEND_SOURCE_DIR}/src/${ASCEND_PROJECT}")
-SET(ASCEND_DST_DIR "ascend")
-SET(ASCEND_INSTALL_ROOT "${THIRD_PARTY_PATH}/install")
-SET(ASCEND_INSTALL_DIR ${ASCEND_INSTALL_ROOT}/${ASCEND_DST_DIR})
-SET(ASCEND_ROOT ${ASCEND_INSTALL_DIR})
-SET(ASCEND_INC_DIR ${ASCEND_ROOT}/include)
-SET(ASCEND_LIB_DIR ${ASCEND_ROOT}/lib)
-SET(ASCEND_LIB ${ASCEND_LIB_DIR}/libge_runner.so)
-SET(ASCEND_GRAPH_LIB ${ASCEND_LIB_DIR}/libgraph.so)
-SET(CMAKE_INSTALL_RPATH "${CMAKE_INSTALL_RPATH}" "${ASCEND_ROOT}/lib")
-
-INCLUDE_DIRECTORIES(${ASCEND_INC_DIR})
-FILE(WRITE ${ASCEND_DOWNLOAD_DIR}/CMakeLists.txt
- "PROJECT(ASCEND)\n"
- "cmake_minimum_required(VERSION 3.0)\n"
- "install(DIRECTORY ${ASCEND_NAME}/include ${ASCEND_NAME}/lib \n"
- " DESTINATION ${ASCEND_DST_DIR})\n")
-ExternalProject_Add(
- ${ASCEND_PROJECT}
- ${EXTERNAL_PROJECT_LOG_ARGS}
- PREFIX ${ASCEND_SOURCE_DIR}
- DOWNLOAD_DIR ${ASCEND_DOWNLOAD_DIR}
- DOWNLOAD_COMMAND wget --no-check-certificate ${ASCEND_URL} -c -q -O ${ASCEND_NAME}.tar.gz
- && tar zxvf ${ASCEND_NAME}.tar.gz
- DOWNLOAD_NO_PROGRESS 1
- UPDATE_COMMAND ""
- CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${ASCEND_INSTALL_ROOT}
- CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${ASCEND_INSTALL_ROOT}
-)
-ADD_LIBRARY(ascend SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET ascend PROPERTY IMPORTED_LOCATION ${ASCEND_LIB})
-
-ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL)
-SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${ASCEND_GRAPH_LIB})
-ADD_DEPENDENCIES(ascend ascend_graph ${ASCEND_PROJECT})
+# NOTE: Logic is from
+# https://github.com/mindspore-ai/graphengine/blob/master/CMakeLists.txt
+if(DEFINED ENV{ASCEND_CUSTOM_PATH})
+ set(ASCEND_DIR $ENV{ASCEND_CUSTOM_PATH})
+else()
+ set(ASCEND_DIR /usr/local/Ascend)
+endif()
+
+if(EXISTS ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include/graph/ascend_string.h)
+  # This header exists only in CANN 20.2 and later
+ add_definitions(-DPADDLE_WITH_ASCEND_STRING)
+endif()
+
+
+if(WITH_ASCEND OR WITH_ASCEND_CL)
+ set(ASCEND_DRIVER_DIR ${ASCEND_DIR}/driver/lib64)
+ set(ASCEND_DRIVER_COMMON_DIR ${ASCEND_DIR}/driver/lib64/common)
+ set(ASCEND_DRIVER_SHARE_DIR ${ASCEND_DIR}/driver/lib64/share)
+ set(ASCEND_RUNTIME_DIR ${ASCEND_DIR}/fwkacllib/lib64)
+ set(ASCEND_ATC_DIR ${ASCEND_DIR}/atc/lib64)
+ set(ASCEND_ACL_DIR ${ASCEND_DIR}/acllib/lib64)
+ set(STATIC_ACL_LIB ${ASCEND_ACL_DIR})
+
+ set(ASCEND_MS_RUNTIME_PATH ${ASCEND_RUNTIME_DIR} ${ASCEND_ACL_DIR} ${ASCEND_ATC_DIR})
+ set(ASCEND_MS_DRIVER_PATH ${ASCEND_DRIVER_DIR} ${ASCEND_DRIVER_COMMON_DIR})
+ set(ATLAS_RUNTIME_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
+ set(ATLAS_RUNTIME_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
+ set(ATLAS_ACL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/lib64)
+ set(ATLAS_ATC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/atc/lib64)
+ set(ATLAS_MS_RUNTIME_PATH ${ATLAS_RUNTIME_DIR} ${ATLAS_ACL_DIR} ${ATLAS_ATC_DIR})
+
+ set(atlas_graph_lib ${ATLAS_RUNTIME_DIR}/libgraph.so)
+ set(atlas_ge_runner_lib ${ATLAS_RUNTIME_DIR}/libge_runner.so)
+ set(atlas_acl_lib ${ATLAS_RUNTIME_DIR}/libascendcl.so)
+ INCLUDE_DIRECTORIES(${ATLAS_RUNTIME_INC_DIR})
+
+
+ ADD_LIBRARY(ascend_ge SHARED IMPORTED GLOBAL)
+ SET_PROPERTY(TARGET ascend_ge PROPERTY IMPORTED_LOCATION ${atlas_ge_runner_lib})
+
+ ADD_LIBRARY(ascend_graph SHARED IMPORTED GLOBAL)
+ SET_PROPERTY(TARGET ascend_graph PROPERTY IMPORTED_LOCATION ${atlas_graph_lib})
+
+ ADD_LIBRARY(atlas_acl SHARED IMPORTED GLOBAL)
+ SET_PROPERTY(TARGET atlas_acl PROPERTY IMPORTED_LOCATION ${atlas_acl_lib})
+
+ add_custom_target(extern_ascend DEPENDS ascend_ge ascend_graph atlas_acl)
+endif()
+
+if(WITH_ASCEND_CL)
+ set(ASCEND_CL_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/lib64)
+
+ set(ascend_hccl_lib ${ASCEND_CL_DIR}/libhccl.so)
+ set(ascendcl_lib ${ASCEND_CL_DIR}/libascendcl.so)
+ set(acl_op_compiler_lib ${ASCEND_CL_DIR}/libacl_op_compiler.so)
+ set(FWKACLLIB_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/fwkacllib/include)
+ set(ACLLIB_INC_DIR ${ASCEND_DIR}/ascend-toolkit/latest/acllib/include)
+
+ message(STATUS "FWKACLLIB_INC_DIR ${FWKACLLIB_INC_DIR}")
+ message(STATUS "ASCEND_CL_DIR ${ASCEND_CL_DIR}")
+ INCLUDE_DIRECTORIES(${FWKACLLIB_INC_DIR})
+ INCLUDE_DIRECTORIES(${ACLLIB_INC_DIR})
+
+ ADD_LIBRARY(ascendcl SHARED IMPORTED GLOBAL)
+ SET_PROPERTY(TARGET ascendcl PROPERTY IMPORTED_LOCATION ${ascendcl_lib})
+
+ ADD_LIBRARY(ascend_hccl SHARED IMPORTED GLOBAL)
+ SET_PROPERTY(TARGET ascend_hccl PROPERTY IMPORTED_LOCATION ${ascend_hccl_lib})
+
+ ADD_LIBRARY(acl_op_compiler SHARED IMPORTED GLOBAL)
+ SET_PROPERTY(TARGET acl_op_compiler PROPERTY IMPORTED_LOCATION ${acl_op_compiler_lib})
+ add_custom_target(extern_ascend_cl DEPENDS ascendcl acl_op_compiler)
+
+endif()
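A minimal sketch of how a consumer target could use the imported libraries defined above; the executable name and source file are hypothetical:

```cmake
# Hypothetical consumer of the imported CANN targets; npu_demo/runner.cc are
# placeholders, not files touched by this diff.
add_executable(npu_demo runner.cc)
target_link_libraries(npu_demo ascendcl ascend_hccl acl_op_compiler)
add_dependencies(npu_demo extern_ascend_cl)
```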
diff --git a/cmake/external/brpc.cmake b/cmake/external/brpc.cmake
index 0eb590c42d0cb73ccb252430bc3e27312b0bddf9..2d72b6eb56deaa2547051756afc075a100aeb251 100644
--- a/cmake/external/brpc.cmake
+++ b/cmake/external/brpc.cmake
@@ -39,9 +39,9 @@ set(prefix_path "${THIRD_PARTY_PATH}/install/gflags|${THIRD_PARTY_PATH}/install/
ExternalProject_Add(
extern_brpc
${EXTERNAL_PROJECT_LOG_ARGS}
- # TODO(gongwb): change to de newst repo when they changed.
+    # TODO(gongwb): switch to the newest upstream repo once it is updated
GIT_REPOSITORY "https://github.com/wangjiawei04/brpc"
- GIT_TAG "6d79e0b17f25107c35b705ea58d888083f59ff47"
+ GIT_TAG "e203afb794caf027da0f1e0776443e7d20c0c28e"
PREFIX ${BRPC_SOURCES_DIR}
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake
index 5a755a816c332a2517ed61caa94d647afd557aae..aa471002eacb6a61a9cf835f293a86a75d87db8f 100644
--- a/cmake/external/eigen.cmake
+++ b/cmake/external/eigen.cmake
@@ -14,11 +14,11 @@
include(ExternalProject)
-# update eigen to the commit id 4da2c6b1 on 03/19/2020
+# update eigen to the commit id f612df27 on 03/16/2021
set(EIGEN_PREFIX_DIR ${THIRD_PARTY_PATH}/eigen3)
set(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3/src/extern_eigen3)
set(EIGEN_REPOSITORY https://gitlab.com/libeigen/eigen.git)
-set(EIGEN_TAG 4da2c6b1974827b1999bab652a3d4703e1992d26)
+set(EIGEN_TAG f612df273689a19d25b45ca4f8269463207c4fee)
cache_third_party(extern_eigen3
REPOSITORY ${EIGEN_REPOSITORY}
@@ -27,47 +27,15 @@ cache_third_party(extern_eigen3
if(WIN32)
add_definitions(-DEIGEN_STRONG_INLINE=inline)
- file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Half.h native_src)
- file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/arch/CUDA/Half.h native_dst)
- # For Windows
- # which will cause a compilation error in Tensor:74:
- # "can not open file 'unistd.h'"
- # so use following patch to solve compilation error On Windows.
- file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Tensor native_src2)
- file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/unsupported/Eigen/CXX11/Tensor native_dst2)
- # For VS2015
- # which will cause a compilation error in TensorBlock.h:1028:
- # "syntax error"
- # so use following patch to solve compilation error On Windows.
- file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/TensorBlock.h native_src3)
- file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/unsupported/Eigen/CXX11/src/Tensor/TensorBlock.h native_dst3)
- set(EIGEN_PATCH_COMMAND copy ${native_src} ${native_dst} /Y && copy ${native_src2} ${native_dst2} /Y && copy ${native_src3} ${native_dst3} /Y)
elseif(LINUX)
- # For gxx=4.8, __GXX_ABI_VERSION is less than 1004
- # which will cause a compilation error in Geometry_SSE.h:38:
- # "no matching function for call to 'pmul(Eigen::internal::Packet4f&, __m128)"
- # refer to: https://gitlab.com/libeigen/eigen/-/blob/4da2c6b1974827b1999bab652a3d4703e1992d26/Eigen/src/Core/arch/SSE/PacketMath.h#L33-60
- # add -fabi-version=4 could avoid above error, but will cause "double free corruption" when compile with gcc8
- # so use following patch to solve compilation error with different version of gcc.
- file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Geometry_SSE.h native_src1)
- file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Geometry/arch/Geometry_SSE.h native_dst1)
- # The compiler fully support const expressions since c++14,
- # but Eigen use some const expressions such as std::max and std::min, which are not supported in c++11
- # add patch to avoid compilation error in c++11
- file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/MathFunctions.h native_src2)
- file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/MathFunctions.h native_dst2)
if(WITH_ROCM)
# For HIPCC Eigen::internal::device::numeric_limits is not EIGEN_DEVICE_FUNC
# which will cause compiler error of using __host__ funciont in __host__ __device__
- file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Meta.h native_src3)
- file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/util/Meta.h native_dst3)
- # For HIPCC Eigen::internal::scalar_sum_op is not EIGEN_DEVICE_FUNC
- # which will cause compiler error of using __host__ funciont in __host__ __device__
- file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/BinaryFunctors.h native_src4)
- file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/functors/BinaryFunctors.h native_dst4)
- set(EIGEN_PATCH_COMMAND cp ${native_src1} ${native_dst1} && cp ${native_src2} ${native_dst2} && cp ${native_src3} ${native_dst3} && cp ${native_src4} ${native_dst4})
- else()
- set(EIGEN_PATCH_COMMAND cp ${native_src1} ${native_dst1} && cp ${native_src2} ${native_dst2})
+ file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/Meta.h native_src)
+ file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/Eigen/src/Core/util/Meta.h native_dst)
+ file(TO_NATIVE_PATH ${PADDLE_SOURCE_DIR}/patches/eigen/TensorReductionGpu.h native_src1)
+ file(TO_NATIVE_PATH ${EIGEN_SOURCE_DIR}/unsupported/Eigen/CXX11/src/Tensor/TensorReductionGpu.h native_dst1)
+ set(EIGEN_PATCH_COMMAND cp ${native_src} ${native_dst} && cp ${native_src1} ${native_dst1})
endif()
endif()
@@ -82,7 +50,7 @@ ExternalProject_Add(
PREFIX ${EIGEN_PREFIX_DIR}
SOURCE_DIR ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
- PATCH_COMMAND ${EIGEN_PATCH_COMMAND}
+ PATCH_COMMAND ${EIGEN_PATCH_COMMAND}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
diff --git a/cmake/external/gloo.cmake b/cmake/external/gloo.cmake
index ea7af315e1a690578bd16c89cc83a158dacca4cf..e8db13a694f5578e314dc1a7c95ed24ad88bad02 100644
--- a/cmake/external/gloo.cmake
+++ b/cmake/external/gloo.cmake
@@ -32,21 +32,39 @@ cache_third_party(extern_gloo
TAG ${GLOO_TAG}
DIR GLOO_SOURCE_DIR)
-ExternalProject_Add(
- extern_gloo
- ${EXTERNAL_PROJECT_LOG_ARGS}
- ${SHALLOW_CLONE}
- "${GLOO_DOWNLOAD_CMD}"
- PREFIX "${GLOO_PREFIX_DIR}"
- SOURCE_DIR "${GLOO_SOURCE_DIR}"
- UPDATE_COMMAND ""
- CONFIGURE_COMMAND ""
- BUILD_COMMAND mkdir -p ${GLOO_SOURCE_DIR}/build
- && cd ${GLOO_SOURCE_DIR}/build && cmake .. && make
- && mkdir -p ${GLOO_LIBRARY_DIR} ${GLOO_INCLUDE_DIR}/gloo
- INSTALL_COMMAND ${CMAKE_COMMAND} -E copy ${GLOO_SOURCE_DIR}/build/gloo/libgloo.a ${GLOO_LIBRARY_DIR}
- COMMAND ${CMAKE_COMMAND} -E copy_directory "${GLOO_SOURCE_DIR}/gloo/" "${GLOO_INCLUDE_DIR}/gloo"
-)
+ if(WITH_ASCEND OR WITH_ASCEND_CL)
+ ExternalProject_Add(
+ extern_gloo
+ ${EXTERNAL_PROJECT_LOG_ARGS}
+ ${SHALLOW_CLONE}
+ "${GLOO_DOWNLOAD_CMD}"
+ PREFIX "${GLOO_PREFIX_DIR}"
+ SOURCE_DIR "${GLOO_SOURCE_DIR}"
+ UPDATE_COMMAND ""
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND mkdir -p ${GLOO_SOURCE_DIR}/build
+ && cd ${GLOO_SOURCE_DIR}/build && cmake .. -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS} && make
+ && mkdir -p ${GLOO_LIBRARY_DIR} ${GLOO_INCLUDE_DIR}/gloo
+ INSTALL_COMMAND ${CMAKE_COMMAND} -E copy ${GLOO_SOURCE_DIR}/build/gloo/libgloo.a ${GLOO_LIBRARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E copy_directory "${GLOO_SOURCE_DIR}/gloo/" "${GLOO_INCLUDE_DIR}/gloo"
+ )
+else()
+ ExternalProject_Add(
+ extern_gloo
+ ${EXTERNAL_PROJECT_LOG_ARGS}
+ ${SHALLOW_CLONE}
+ "${GLOO_DOWNLOAD_CMD}"
+ PREFIX "${GLOO_PREFIX_DIR}"
+ SOURCE_DIR "${GLOO_SOURCE_DIR}"
+ UPDATE_COMMAND ""
+ CONFIGURE_COMMAND ""
+ BUILD_COMMAND mkdir -p ${GLOO_SOURCE_DIR}/build
+ && cd ${GLOO_SOURCE_DIR}/build && cmake .. && make
+ && mkdir -p ${GLOO_LIBRARY_DIR} ${GLOO_INCLUDE_DIR}/gloo
+ INSTALL_COMMAND ${CMAKE_COMMAND} -E copy ${GLOO_SOURCE_DIR}/build/gloo/libgloo.a ${GLOO_LIBRARY_DIR}
+ COMMAND ${CMAKE_COMMAND} -E copy_directory "${GLOO_SOURCE_DIR}/gloo/" "${GLOO_INCLUDE_DIR}/gloo"
+ )
+endif()
ADD_LIBRARY(gloo STATIC IMPORTED GLOBAL)
diff --git a/cmake/external/grpc.cmake b/cmake/external/grpc.cmake
deleted file mode 100644
index 536e95c1dc2a4fe6545bd5d3147631aa26cdda98..0000000000000000000000000000000000000000
--- a/cmake/external/grpc.cmake
+++ /dev/null
@@ -1,77 +0,0 @@
-# Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-include (ExternalProject)
-
-SET(GRPC_SOURCES_DIR ${THIRD_PARTY_PATH}/grpc)
-SET(GRPC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/grpc)
-SET(GRPC_INCLUDE_DIR "${GRPC_INSTALL_DIR}/include/" CACHE PATH "grpc include directory." FORCE)
-SET(GRPC_CPP_PLUGIN "${GRPC_INSTALL_DIR}/bin/grpc_cpp_plugin" CACHE FILEPATH "GRPC_CPP_PLUGIN" FORCE)
-
-include(ProcessorCount)
-ProcessorCount(NUM_OF_PROCESSOR)
-
-IF(APPLE)
- SET(BUILD_CMD make -n HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin | sed "s/-Werror//g" | sh)
- SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install)
-ELSE()
- SET(GRPC_CFLAGS "-Wno-error -std=c11 ${CLFAGS}")
- SET(GRPC_CXXFLAGS "-Wno-error -std=c++11 ${CXXFLAGS}")
- SET(BUILD_CMD make CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS} HAS_SYSTEM_PROTOBUF=false -s -j ${NUM_OF_PROCESSOR} static grpc_cpp_plugin)
- SET(GRPC_INSTALL_CMD make prefix=${GRPC_INSTALL_DIR} install CFLAGS=${GRPC_CFLAGS} CXXFLAGS=${GRPC_CXXFLAGS})
-ENDIF()
-
-# FIXME(wuyi): do not build zlib cares protobuf twice, find a way to build grpc with them
-ExternalProject_Add(
- extern_grpc
- DEPENDS protobuf zlib
- # NOTE(wuyi):
- # this package is generated by following steps:
- # 1. git clone -b v1.8.x https://github.com/grpc/grpc.git
- # 2. git submodule update --init
- # 3. keep only zlib, cares, protobuf, boringssl under "third_party",
- # checkout and clean other dirs under third_party
- # 4. remove .git, and package the directory.
- URL http://paddlepaddledeps.bj.bcebos.com/grpc-v1.10.x_paddle.tar.gz
- URL_MD5 f5442d137ddccee252e194b1bc90f98c
- PREFIX ${GRPC_SOURCES_DIR}
- UPDATE_COMMAND ""
- CONFIGURE_COMMAND ""
- BUILD_IN_SOURCE 1
- # NOTE(yuyang18):
- # Disable -Werror, otherwise the compile will fail in MacOS.
- # It seems that we cannot configure that by make command.
- # Just dry run make command and remove `-Werror`, then use a shell to run make commands
- BUILD_COMMAND ${BUILD_CMD}
- INSTALL_COMMAND ${GRPC_INSTALL_CMD}
-)
-
-ADD_LIBRARY(grpc++_unsecure STATIC IMPORTED GLOBAL)
-SET_PROPERTY(TARGET grpc++_unsecure PROPERTY IMPORTED_LOCATION
- "${GRPC_INSTALL_DIR}/lib/libgrpc++_unsecure.a")
-
-ADD_LIBRARY(grpc++ STATIC IMPORTED GLOBAL)
-SET_PROPERTY(TARGET grpc++ PROPERTY IMPORTED_LOCATION
- "${GRPC_INSTALL_DIR}/lib/libgrpc++.a")
-ADD_LIBRARY(gpr STATIC IMPORTED GLOBAL)
-SET_PROPERTY(TARGET gpr PROPERTY IMPORTED_LOCATION
- "${GRPC_INSTALL_DIR}/lib/libgpr.a")
-
-ADD_LIBRARY(grpc_unsecure STATIC IMPORTED GLOBAL)
-SET_PROPERTY(TARGET grpc_unsecure PROPERTY IMPORTED_LOCATION
- "${GRPC_INSTALL_DIR}/lib/libgrpc_unsecure.a")
-
-include_directories(${GRPC_INCLUDE_DIR})
-ADD_DEPENDENCIES(grpc++_unsecure extern_grpc)
diff --git a/cmake/external/mkldnn.cmake b/cmake/external/mkldnn.cmake
index 884219d8dd81f30e17f7a86380947262014e402a..fb1d4d9d56dcc6f38a86242b4d78b88ef31ddaa0 100644
--- a/cmake/external/mkldnn.cmake
+++ b/cmake/external/mkldnn.cmake
@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY ${GIT_URL}/oneapi-src/oneDNN.git)
-SET(MKLDNN_TAG 72efa005effb49595933e033cc732f215ef0445a)
+SET(MKLDNN_TAG f58682cd8bd0615f41d879f8afc8f1511ab42d24)
# Introduce variables:
# * CMAKE_INSTALL_LIBDIR
diff --git a/cmake/external/protobuf.cmake b/cmake/external/protobuf.cmake
index 40a27f506f3077a5a47289d20906f7c180681b65..c108c05368c915f6d4998d46713cda315dfb93ff 100644
--- a/cmake/external/protobuf.cmake
+++ b/cmake/external/protobuf.cmake
@@ -198,8 +198,16 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
"-Dprotobuf_MSVC_STATIC_RUNTIME=${MSVC_STATIC_CRT}")
ENDIF()
+if(WITH_ASCEND AND NOT WITH_ASCEND_CXX11)
+ SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
+ SET(PROTOBUF_TAG v3.8.0)
+elseif(WITH_ASCEND_CL AND NOT WITH_ASCEND_CXX11)
+ SET(PROTOBUF_REPOSITORY https://gitee.com/tianjianhe/protobuf.git)
+ SET(PROTOBUF_TAG v3.8.0)
+else()
SET(PROTOBUF_REPOSITORY ${GIT_URL}/protocolbuffers/protobuf.git)
SET(PROTOBUF_TAG 9f75c5aa851cd877fb0d93ccc31b8567a6706546)
+endif()
cache_third_party(${TARGET_NAME}
REPOSITORY ${PROTOBUF_REPOSITORY}
@@ -234,7 +242,11 @@ FUNCTION(build_protobuf TARGET_NAME BUILD_FOR_HOST)
)
ENDFUNCTION()
-SET(PROTOBUF_VERSION 3.1.0)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
+ SET(PROTOBUF_VERSION 3.8.0)
+else()
+ SET(PROTOBUF_VERSION 3.1.0)
+endif()
IF(NOT PROTOBUF_FOUND)
build_protobuf(extern_protobuf FALSE)
diff --git a/cmake/external/threadpool.cmake b/cmake/external/threadpool.cmake
index 205e8d26d93ca1c25e5b59ecc3b063b4837db77b..f9cb3a9075a821025129c1f6acb479a4ad6ac95c 100644
--- a/cmake/external/threadpool.cmake
+++ b/cmake/external/threadpool.cmake
@@ -16,7 +16,11 @@ INCLUDE(ExternalProject)
SET(THREADPOOL_PREFIX_DIR ${THIRD_PARTY_PATH}/threadpool)
SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool/src/extern_threadpool)
-SET(THREADPOOL_REPOSITORY ${GIT_URL}/progschj/ThreadPool.git)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
+ SET(THREADPOOL_REPOSITORY https://gitee.com/tianjianhe/ThreadPool.git)
+else()
+ SET(THREADPOOL_REPOSITORY ${GIT_URL}/progschj/ThreadPool.git)
+endif()
SET(THREADPOOL_TAG 9a42ec1329f259a5f4881a291db1dcb8f2ad9040)
cache_third_party(extern_threadpool
diff --git a/cmake/external/warpctc.cmake b/cmake/external/warpctc.cmake
index 0ee3e2116a94b68d528a475a453d1c31f0464cf4..c591a9391dfa5d3b5a452ffbb5a5d3199d387519 100644
--- a/cmake/external/warpctc.cmake
+++ b/cmake/external/warpctc.cmake
@@ -14,11 +14,17 @@
INCLUDE(ExternalProject)
+IF(WITH_ROCM)
+ add_definitions(-DWARPCTC_WITH_HIP)
+ENDIF()
+
SET(WARPCTC_PREFIX_DIR ${THIRD_PARTY_PATH}/warpctc)
SET(WARPCTC_SOURCE_DIR ${THIRD_PARTY_PATH}/warpctc/src/extern_warpctc)
SET(WARPCTC_INSTALL_DIR ${THIRD_PARTY_PATH}/install/warpctc)
+# In case of low internet speed, use this mirror instead:
+#set(WARPCTC_REPOSITORY   https://gitee.com/tianjianhe/warp-ctc.git)
set(WARPCTC_REPOSITORY ${GIT_URL}/baidu-research/warp-ctc.git)
-set(WARPCTC_TAG 95a461eddeabd51099ef059dcfada1117eb1bfb8)
+set(WARPCTC_TAG c690fc5755abbdbdc98ef78d51ec10a6748a8cd1)
SET(WARPCTC_INCLUDE_DIR "${WARPCTC_INSTALL_DIR}/include"
CACHE PATH "Warp-ctc Directory" FORCE)
@@ -37,38 +43,92 @@ cache_third_party(extern_warpctc
TAG ${WARPCTC_TAG}
DIR WARPCTC_SOURCE_DIR)
-ExternalProject_Add(
- extern_warpctc
- ${EXTERNAL_PROJECT_LOG_ARGS}
- ${SHALLOW_CLONE}
- "${WARPCTC_DOWNLOAD_CMD}"
- PREFIX ${WARPCTC_PREFIX_DIR}
- SOURCE_DIR ${WARPCTC_SOURCE_DIR}
- #UPDATE_COMMAND ""
- PATCH_COMMAND ""
- BUILD_ALWAYS 1
- CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
- -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
- -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
- -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
- -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
- -DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}
- -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
- -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
- -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
- -DWITH_GPU=${WITH_GPU}
- -DWITH_OMP=${USE_OMP}
- -DWITH_TORCH=OFF
- -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
- -DBUILD_SHARED=ON
- -DBUILD_TESTS=OFF
- -DCMAKE_POSITION_INDEPENDENT_CODE=ON
- -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
- ${EXTERNAL_OPTIONAL_ARGS}
- CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
- -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
- -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
-)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
+ ExternalProject_Add(
+ extern_warpctc
+ ${EXTERNAL_PROJECT_LOG_ARGS}
+ ${SHALLOW_CLONE}
+ "${WARPCTC_DOWNLOAD_CMD}"
+ PREFIX ${WARPCTC_PREFIX_DIR}
+ SOURCE_DIR ${WARPCTC_SOURCE_DIR}
+ #UPDATE_COMMAND ""
+ PATCH_COMMAND ""
+ BUILD_ALWAYS 1
+ CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+ -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+ -DCMAKE_C_FLAGS=${CMAKE_C_FLAGS}
+ -DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
+ -DCMAKE_C_FLAGS_RELEASE=${CMAKE_C_FLAGS_RELEASE}
+ "-DCMAKE_CXX_FLAGS=${CMAKE_CXX_FLAGS}"
+ -DCMAKE_CXX_FLAGS_RELEASE=${CMAKE_CXX_FLAGS_RELEASE}
+ -DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
+ -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
+ -DWITH_GPU=${WITH_GPU}
+ -DWITH_ROCM=${WITH_ROCM}
+ -DWITH_OMP=${USE_OMP}
+ -DWITH_TORCH=OFF
+ -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
+ -DBUILD_SHARED=ON
+ -DBUILD_TESTS=OFF
+ -DCMAKE_POSITION_INDEPENDENT_CODE=ON
+ -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+ ${EXTERNAL_OPTIONAL_ARGS}
+ CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
+ -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
+ -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
+ )
+else()
+ if(WIN32)
+    set(WARPCTC_C_FLAGS $<FILTER:${CMAKE_C_FLAGS},EXCLUDE,/Zc:inline>)
+    set(WARPCTC_C_FLAGS_DEBUG $<FILTER:${CMAKE_C_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
+    set(WARPCTC_C_FLAGS_RELEASE $<FILTER:${CMAKE_C_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
+    set(WARPCTC_CXX_FLAGS $<FILTER:${CMAKE_CXX_FLAGS},EXCLUDE,/Zc:inline>)
+    set(WARPCTC_CXX_FLAGS_RELEASE $<FILTER:${CMAKE_CXX_FLAGS_RELEASE},EXCLUDE,/Zc:inline>)
+    set(WARPCTC_CXX_FLAGS_DEBUG $<FILTER:${CMAKE_CXX_FLAGS_DEBUG},EXCLUDE,/Zc:inline>)
+ else()
+ set(WARPCTC_C_FLAGS ${CMAKE_C_FLAGS})
+ set(WARPCTC_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
+ set(WARPCTC_C_FLAGS_RELEASE ${CMAKE_C_FLAGS_RELEASE})
+ set(WARPCTC_CXX_FLAGS ${CMAKE_CXX_FLAGS})
+ set(WARPCTC_CXX_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE})
+ set(WARPCTC_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
+ endif()
+ ExternalProject_Add(
+ extern_warpctc
+ ${EXTERNAL_PROJECT_LOG_ARGS}
+ ${SHALLOW_CLONE}
+ "${WARPCTC_DOWNLOAD_CMD}"
+ PREFIX ${WARPCTC_PREFIX_DIR}
+ SOURCE_DIR ${WARPCTC_SOURCE_DIR}
+ #UPDATE_COMMAND ""
+ PATCH_COMMAND ""
+ BUILD_ALWAYS 1
+ CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
+ -DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
+ -DCMAKE_C_FLAGS=${WARPCTC_C_FLAGS}
+ -DCMAKE_C_FLAGS_DEBUG=${WARPCTC_C_FLAGS_DEBUG}
+ -DCMAKE_C_FLAGS_RELEASE=${WARPCTC_C_FLAGS_RELEASE}
+ -DCMAKE_CXX_FLAGS=${WARPCTC_CXX_FLAGS}
+ -DCMAKE_CXX_FLAGS_RELEASE=${WARPCTC_CXX_FLAGS_RELEASE}
+ -DCMAKE_CXX_FLAGS_DEBUG=${WARPCTC_CXX_FLAGS_DEBUG}
+ -DCMAKE_INSTALL_PREFIX=${WARPCTC_INSTALL_DIR}
+ -DWITH_GPU=${WITH_GPU}
+ -DWITH_ROCM=${WITH_ROCM}
+ -DWITH_OMP=${USE_OMP}
+ -DWITH_TORCH=OFF
+ -DCMAKE_DISABLE_FIND_PACKAGE_Torch=ON
+ -DBUILD_SHARED=ON
+ -DBUILD_TESTS=OFF
+ -DCMAKE_POSITION_INDEPENDENT_CODE=ON
+ -DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
+ ${EXTERNAL_OPTIONAL_ARGS}
+ CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
+ -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
+ -DCMAKE_INSTALL_PREFIX:PATH=${WARPCTC_INSTALL_DIR}
+ )
+endif()
+
+
IF(WIN32)
SET(WARPCTC_LIBRARIES "${WARPCTC_INSTALL_DIR}/bin/warpctc${CMAKE_SHARED_LIBRARY_SUFFIX}"
CACHE FILEPATH "Warp-ctc Library" FORCE)
diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake
index b5a3f0154745b9425c3dfc45a129117238fa80de..f846623602ed79a5bd84268436a59ede1957364b 100644
--- a/cmake/external/xpu.cmake
+++ b/cmake/external/xpu.cmake
@@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
else()
- SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_02_27.tar.gz" CACHE STRING "" FORCE)
+ SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_04_09.tar.gz" CACHE STRING "" FORCE)
endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
diff --git a/cmake/flags.cmake b/cmake/flags.cmake
index e110524dd1abb864649daf8bd763e69ae87c600d..a2ddad557c2956f7de21bceaf7a6699e8dfbed43 100644
--- a/cmake/flags.cmake
+++ b/cmake/flags.cmake
@@ -4,10 +4,10 @@ include(CheckCCompilerFlag)
include(CheckCXXSymbolExists)
include(CheckTypeSize)
-function(CheckCompilerCXX11Flag)
+function(CheckCompilerCXX14Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
- if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
- message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
+ if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.4)
+ message(FATAL_ERROR "Unsupported GCC version. GCC >= 5.4 required.")
elseif(${CMAKE_CXX_COMPILER_VERSION} VERSION_GREATER 8.2)
message(WARNING "Found GCC ${CMAKE_CXX_COMPILER_VERSION} which is too high, recommended to use GCC 8.2")
endif()
@@ -20,23 +20,15 @@ function(CheckCompilerCXX11Flag)
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
endif()
else()
- if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
- message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
+ if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.4)
+ message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.4 required.")
endif()
endif()
endif()
endfunction()
-CheckCompilerCXX11Flag()
-if (WITH_GPU)
- if (${CMAKE_CUDA_COMPILER_VERSION} GREATER_EQUAL 11.0)
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
- else()
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
- endif()
-else()
- set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
-endif()
+CheckCompilerCXX14Flag()
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14")
# safe_set_flag
#
# Set a compile flag only if compiler is support
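A quick way to probe the new C++14 baseline outside a full build, as a sketch using CMake's stock compile-check module (the cache variable name is arbitrary):

```cmake
# Probe that the toolchain accepts the -std=c++14 flag enforced above.
include(CheckCXXSourceCompiles)
set(CMAKE_REQUIRED_FLAGS "-std=c++14")
check_cxx_source_compiles("
  #include <memory>
  int main() { auto p = std::make_unique<int>(42); return *p == 42 ? 0 : 1; }"
  COMPILER_SUPPORTS_CXX14)
if(NOT COMPILER_SUPPORTS_CXX14)
  message(FATAL_ERROR "Compiler rejected the C++14 probe")
endif()
```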
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index ba86cfabdf173467973b9d4337e6edbbe84c5889..a5c74a46631e9d76fa78261f706a1853a80bab32 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -447,9 +447,20 @@ function(cc_test TARGET_NAME)
cc_test_build(${TARGET_NAME}
SRCS ${cc_test_SRCS}
DEPS ${cc_test_DEPS})
- cc_test_run(${TARGET_NAME}
- COMMAND ${TARGET_NAME}
- ARGS ${cc_test_ARGS})
+    # We don't test the HCOM ops, because they need a complex
+    # configuration with more than one machine.
+ if(NOT ("${TARGET_NAME}" STREQUAL "c_broadcast_op_npu_test" OR
+ "${TARGET_NAME}" STREQUAL "c_allreduce_sum_op_npu_test" OR
+ "${TARGET_NAME}" STREQUAL "c_allreduce_max_op_npu_test" OR
+ "${TARGET_NAME}" STREQUAL "c_reducescatter_op_npu_test" OR
+ "${TARGET_NAME}" STREQUAL "c_allgather_op_npu_test" OR
+ "${TARGET_NAME}" STREQUAL "send_v2_op_npu_test" OR
+ "${TARGET_NAME}" STREQUAL "c_reduce_sum_op_npu_test" OR
+ "${TARGET_NAME}" STREQUAL "recv_v2_op_npu_test"))
+ cc_test_run(${TARGET_NAME}
+ COMMAND ${TARGET_NAME}
+ ARGS ${cc_test_ARGS})
+ endif()
endif()
endfunction(cc_test)
@@ -492,10 +503,8 @@ function(nv_library TARGET_NAME)
message(FATAL "Please specify source file or library in nv_library.")
endif()
endif(nv_library_SRCS)
- if (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
- if(${MSVC_VERSION} LESS_EQUAL 1900)
- set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
- endif()
+ if((CUDA_VERSION GREATER 9.2) AND (CUDA_VERSION LESS 11.0) AND (MSVC_VERSION LESS 1910))
+ set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
endif()
endif()
endfunction(nv_library)
@@ -512,7 +521,7 @@ function(nv_binary TARGET_NAME)
add_dependencies(${TARGET_NAME} ${nv_binary_DEPS})
common_link(${TARGET_NAME})
endif()
- if (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
+ if((CUDA_VERSION GREATER 9.2) AND (CUDA_VERSION LESS 11.0) AND (MSVC_VERSION LESS 1910))
set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
endif()
endif()
@@ -539,7 +548,7 @@ function(nv_test TARGET_NAME)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cpu_deterministic=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_init_allocated_mem=true)
set_property(TEST ${TARGET_NAME} PROPERTY ENVIRONMENT FLAGS_cudnn_deterministic=true)
- if (WIN32 AND ${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0)
+ if((CUDA_VERSION GREATER 9.2) AND (CUDA_VERSION LESS 11.0) AND (MSVC_VERSION LESS 1910))
set_target_properties(${TARGET_NAME} PROPERTIES VS_USER_PROPS ${WIN_PROPS})
endif()
endif()
@@ -809,7 +818,7 @@ function(py_test TARGET_NAME)
${PYTHON_EXECUTABLE} -u ${py_test_SRCS} ${py_test_ARGS}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
endif()
-
+
if (WIN32)
set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 150)
endif()
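With the guard above, a test on the exclusion list still compiles but never registers a ctest entry. A sketch of such a call, with SRCS/DEPS assumed for illustration:

```cmake
# Builds the binary; because the name is on the skip list in cc_test(), no
# cc_test_run()/ctest case is emitted. Sources and deps are placeholders.
cc_test(c_broadcast_op_npu_test
  SRCS c_broadcast_op_npu_test.cc
  DEPS op_registry)
```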
diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake
index 2cba3d06936081097a773295c7f91e7aa53564a6..9694a7bc59c12a96e1c0c33488895ae94dbf2a03 100644
--- a/cmake/inference_lib.cmake
+++ b/cmake/inference_lib.cmake
@@ -192,6 +192,15 @@ include_directories(${CMAKE_BINARY_DIR}/../paddle/fluid/framework/io)
copy(inference_lib_dist
SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/extension/include/*
DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
+copy(inference_lib_dist
+ SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex64.h
+ DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
+copy(inference_lib_dist
+ SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/complex128.h
+ DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
+copy(inference_lib_dist
+ SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/float16.h
+ DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/experimental/)
# CAPI inference library for only inference
set(PADDLE_INFERENCE_C_INSTALL_DIR "${CMAKE_BINARY_DIR}/paddle_inference_c_install_dir" CACHE STRING
@@ -202,11 +211,11 @@ set(src_dir "${PADDLE_SOURCE_DIR}/paddle/fluid")
if(WIN32)
set(paddle_inference_c_lib $<TARGET_FILE_DIR:paddle_inference_c>/paddle_inference_c.*)
else(WIN32)
- set(paddle_inference_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi/libpaddle_inference_c.*)
+ set(paddle_inference_c_lib ${PADDLE_BINARY_DIR}/paddle/fluid/inference/capi_exp/libpaddle_inference_c.*)
endif(WIN32)
copy(inference_lib_dist
- SRCS ${src_dir}/inference/capi/paddle_c_api.h ${paddle_inference_c_lib}
+ SRCS ${src_dir}/inference/capi_exp/pd_*.h ${paddle_inference_c_lib}
DSTS ${PADDLE_INFERENCE_C_INSTALL_DIR}/paddle/include ${PADDLE_INFERENCE_C_INSTALL_DIR}/paddle/lib)
# fluid library for both train and inference
diff --git a/cmake/init.cmake b/cmake/init.cmake
index aea02088750df4edc71a4909489c8ba250c8bb64..b11156d2e9986f879dcf4dd63354edb81c493260 100644
--- a/cmake/init.cmake
+++ b/cmake/init.cmake
@@ -18,6 +18,10 @@ if(NOT WIN32)
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O2 -g -DNDEBUG")
set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG")
else()
+  # The props file specifies CUDA compile flags manually; it is used to remove
+  # /Zi and reduce the GPU static library size. This is fragile, because NVIDIA
+  # may update CUDA and the hand-written flags would then break the build.
+  # Currently it is only used with VS2015 + CUDA [10.0, 10.2].
set(WIN_PROPS ${CMAKE_SOURCE_DIR}/cmake/paddle_win.props)
endif()
diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 0343ff3cc292d97dcc77108735baa69c804468af..33390745cc8c96bc00b9eab84dfb637a8a76c2f9 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -11,6 +11,7 @@ function(op_library TARGET)
set(cu_cc_srcs)
set(hip_cc_srcs)
set(xpu_cc_srcs)
+ set(npu_cc_srcs)
set(cudnn_cu_cc_srcs)
set(miopen_cu_cc_srcs)
set(cudnn_cu_srcs)
@@ -20,6 +21,9 @@ function(op_library TARGET)
set(mkldnn_cc_srcs)
set(MKLDNN_FILE)
set(op_common_deps operator op_registry math_function layer common_infer_shape_functions)
+ if (WITH_ASCEND_CL)
+ set(op_common_deps ${op_common_deps} npu_op_runner)
+ endif()
# Option `UNITY` is used to specify that operator `TARGET` will compiles with Unity Build.
set(options UNITY)
set(oneValueArgs "")
@@ -40,6 +44,9 @@ function(op_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu)
endif()
+ if (WITH_NV_JETSON)
+ list(REMOVE_ITEM cu_srcs "decode_jpeg_op.cu")
+ endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
@@ -85,6 +92,12 @@ function(op_library TARGET)
list(APPEND xpu_cc_srcs ${XPU_FILE}.cc)
endif()
endif()
+ if(WITH_ASCEND_CL)
+ string(REPLACE "_op" "_op_npu" NPU_FILE "${TARGET}")
+ if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${NPU_FILE}.cc)
+ list(APPEND npu_cc_srcs ${NPU_FILE}.cc)
+ endif()
+ endif()
else()
foreach(src ${op_library_SRCS})
if(WITH_ROCM AND ${src} MATCHES ".*_cudnn_op.cu$")
@@ -107,6 +120,8 @@ function(op_library TARGET)
list(APPEND cu_cc_srcs ${src})
elseif(WITH_XPU AND ${src} MATCHES ".*_op_xpu.cc$")
list(APPEND xpu_cc_srcs ${src})
+ elseif(WITH_ASCEND_CL AND ${src} MATCHES ".*_op_npu.cc$")
+ list(APPEND npu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$")
list(APPEND cc_srcs ${src})
else()
@@ -168,15 +183,15 @@ function(op_library TARGET)
list(REMOVE_ITEM miopen_cu_cc_srcs "affine_grid_cudnn_op.cu.cc")
list(REMOVE_ITEM miopen_cu_cc_srcs "grid_sampler_cudnn_op.cu.cc")
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
- list(REMOVE_ITEM hip_srcs "correlation_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
+ list(REMOVE_ITEM hip_srcs "decode_jpeg_op.cu")
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
# Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
if(WITH_UNITY_BUILD AND op_library_UNITY)
# Combine the cc source files.
- compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs})
+ compose_unity_target_sources(${UNITY_TARGET} cc ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs})
if(TARGET ${UNITY_TARGET})
# If `UNITY_TARGET` exists, add source files to `UNITY_TARGET`.
target_sources(${UNITY_TARGET} PRIVATE ${unity_target_cc_sources})
@@ -187,7 +202,7 @@ function(op_library TARGET)
# Add alias library to handle dependencies.
add_library(${TARGET} ALIAS ${UNITY_TARGET})
else()
- cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} DEPS ${op_library_DEPS}
+ cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${npu_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
endif()
endif()
@@ -207,6 +222,7 @@ function(op_library TARGET)
# The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
# Note that it's enough to just adding one operator to pybind in a *_op.cc file.
# And for detail pybind information, please see generated paddle/pybind/pybind.h.
+ set(ORIGINAL_TARGET ${TARGET})
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
# [ \t\r\n]* is used for blank characters
@@ -239,8 +255,9 @@ function(op_library TARGET)
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
list(LENGTH xpu_cc_srcs xpu_cc_srcs_len)
list(LENGTH miopen_cu_cc_srcs miopen_cu_cc_srcs_len)
+ list(LENGTH npu_cc_srcs npu_cc_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
- ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0)
+ ${hip_srcs_len} EQUAL 0 AND ${hip_cc_srcs_len} EQUAL 0 AND ${miopen_cu_cc_srcs_len} EQUAL 0 AND ${xpu_cc_srcs_len} EQUAL 0 AND ${npu_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
@@ -280,6 +297,26 @@ function(op_library TARGET)
if (WITH_XPU AND ${xpu_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, XPU);\n")
endif()
+
+ if (WITH_ASCEND_CL AND ${npu_cc_srcs_len} GREATER 0)
+ file(READ ${ORIGINAL_TARGET}_npu.cc TARGET_NPU_CONTENT)
+        # Note that this differs from the registration logic above; be careful.
+ string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\(.*" multi_npu_register "${TARGET_NPU_CONTENT}")
+ # [ \t\r\n]* is used for blank characters
+ string(REGEX MATCH "REGISTER_OP_NPU_KERNEL\\([ \t\r\n]*[a-z0-9_]*," one_npu_register "${multi_npu_register}")
+
+ if (one_npu_register STREQUAL "")
+ string(REPLACE "_op" "" NPU_TARGET "${TARGET}")
+ else ()
+ string(REPLACE "REGISTER_OP_NPU_KERNEL(" "" NPU_TARGET "${one_npu_register}")
+ string(REPLACE "," "" NPU_TARGET "${NPU_TARGET}")
+ # [ \t\r\n]+ is used for blank characters.
+ # Here we use '+' instead of '*' since it is a REPLACE operation.
+ string(REGEX REPLACE "[ \t\r\n]+" "" NPU_TARGET "${NPU_TARGET}")
+ endif()
+ file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${NPU_TARGET}, NPU);\n")
+ endif()
+
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
# Append first implemented MKLDNN activation operator
@@ -330,6 +367,7 @@ function(register_operators)
file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
string(REPLACE "_mkldnn" "" OPS "${OPS}")
string(REPLACE "_xpu" "" OPS "${OPS}")
+ string(REPLACE "_npu" "" OPS "${OPS}")
string(REPLACE ".cc" "" OPS "${OPS}")
list(REMOVE_DUPLICATES OPS)
list(LENGTH register_operators_DEPS register_operators_DEPS_len)
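To make the NPU pybind plumbing concrete: a sketch of the round trip from an op source file to the generated pybind line, using a hypothetical op name:

```cmake
# Given a hypothetical relu_op_npu.cc containing
#   REGISTER_OP_NPU_KERNEL(relu, ReluNPUKernel<float>);
# the regex in op_library() extracts "relu" and appends this to pybind_file:
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, NPU);\n")
```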
diff --git a/cmake/paddle_win.props b/cmake/paddle_win.props
index 0115ad4b59fc466ea10be6912257c40d31ed3640..3c069bd2981c437a1450ede29db2449dc46a9a4a 100644
--- a/cmake/paddle_win.props
+++ b/cmake/paddle_win.props
@@ -15,7 +15,7 @@
InheritFromHost
-ccbin "%(VCBinDir)" -x cu [GenerateRelocatableDeviceCode] [Include] [RequiredIncludes] [InterleaveSourceInPTX] [GPUDebugInfo] [GenerateLineInfo] [Keep] [KeepDir] [MaxRegCount] [PtxAsOptionV] [TargetMachinePlatform] [NvccCompilation] [CudaRuntime] [AdditionalOptions]
- --use-local-env --cl-version $(CudaClVersion)
+ --use-local-env $(CudaClVersion)
[CodeGeneration]
-clean
@@ -88,4 +88,3 @@ set CUDAFE_FLAGS=--sdk_dir "$(WindowsSdkDir)"
-
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index 6488d29afc5f7f4af72aab1cf2463d900a89fa9d..56edaff2a50dab0f7029ec1e85fc3d4ce8ac416e 100644
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -29,9 +29,9 @@ set(third_party_deps)
# 2. REPOSITORY: specify git REPOSITORY of 3rd party
# 3. TAG: specify git tag/branch/commitID of 3rd party
# 4. DIR: overwrite the original SOURCE_DIR when cache directory
-#
+#
# The function Return 1 PARENT_SCOPE variables:
-# - ${TARGET}_DOWNLOAD_CMD: Simply place "${TARGET}_DOWNLOAD_CMD" in ExternalProject_Add,
+# - ${TARGET}_DOWNLOAD_CMD: Simply place "${TARGET}_DOWNLOAD_CMD" in ExternalProject_Add,
# and you no longer need to set any donwnload steps in ExternalProject_Add.
# For example:
# Cache_third_party(${TARGET}
@@ -52,7 +52,7 @@ FUNCTION(cache_third_party TARGET)
SET(${TARGET_NAME}_DOWNLOAD_CMD
GIT_REPOSITORY ${cache_third_party_REPOSITORY})
IF(cache_third_party_TAG)
- LIST(APPEND ${TARGET_NAME}_DOWNLOAD_CMD
+ LIST(APPEND ${TARGET_NAME}_DOWNLOAD_CMD
GIT_TAG ${cache_third_party_TAG})
ENDIF()
ELSEIF(cache_third_party_URL)
@@ -130,7 +130,7 @@ ENDFUNCTION()
# Correction of flags on different Platform(WIN/MAC) and Print Warning Message
if (APPLE)
if(WITH_MKL)
- MESSAGE(WARNING
+ MESSAGE(WARNING
"Mac is not supported with MKL in Paddle yet. Force WITH_MKL=OFF.")
set(WITH_MKL OFF CACHE STRING "Disable MKL for building on mac" FORCE)
endif()
@@ -141,7 +141,7 @@ if(WIN32 OR APPLE)
SET(WITH_XBYAK OFF CACHE STRING "Disable XBYAK in Windows and MacOS" FORCE)
if(WITH_LIBXSMM)
- MESSAGE(WARNING
+ MESSAGE(WARNING
"Windows, Mac are not supported with libxsmm in Paddle yet."
"Force WITH_LIBXSMM=OFF")
SET(WITH_LIBXSMM OFF CACHE STRING "Disable LIBXSMM in Windows and MacOS" FORCE)
@@ -261,6 +261,14 @@ if(WITH_PSLIB)
if(WITH_PSLIB_BRPC)
include(external/pslib_brpc) # download, build, install pslib_brpc
list(APPEND third_party_deps extern_pslib_brpc)
+ else()
+ include(external/snappy)
+ list(APPEND third_party_deps extern_snappy)
+
+ include(external/leveldb)
+ list(APPEND third_party_deps extern_leveldb)
+ include(external/brpc)
+ list(APPEND third_party_deps extern_brpc)
endif()
endif(WITH_PSLIB)
@@ -274,10 +282,15 @@ if(WITH_BOX_PS)
list(APPEND third_party_deps extern_box_ps)
endif(WITH_BOX_PS)
-if(WITH_ASCEND)
+if(WITH_ASCEND OR WITH_ASCEND_CL)
include(external/ascend)
- list(APPEND third_party_deps extern_ascend)
-endif (WITH_ASCEND)
+ if(WITH_ASCEND OR WITH_ASCEND_CL)
+ list(APPEND third_party_deps extern_ascend)
+ endif()
+ if(WITH_ASCEND_CL)
+ list(APPEND third_party_deps extern_ascend_cl)
+ endif()
+endif ()
if (WITH_PSCORE)
include(external/snappy)
@@ -285,7 +298,7 @@ if (WITH_PSCORE)
include(external/leveldb)
list(APPEND third_party_deps extern_leveldb)
-
+
include(external/brpc)
list(APPEND third_party_deps extern_brpc)
diff --git a/go/README_cn.md b/go/README_cn.md
index a184ecbb8dea1ae71074ef9686d088a5f4cf0f33..040540e939bc3a0993e7c963b281ad91fbfe1ffc 100644
--- a/go/README_cn.md
+++ b/go/README_cn.md
@@ -50,6 +50,7 @@ output_data := value.Interface().([][]float32)
运行
```bash
+go mod init github.com/paddlepaddle
export LD_LIBRARY_PATH=`pwd`/paddle_c/paddle/lib:$LD_LIBRARY_PATH
go run ./demo/mobilenet.go
```
diff --git a/go/demo/mobilenet.go b/go/demo/mobilenet.go
index 1b42fe8049a584616da7b4940fd19a89df9bc52b..c1ca2e967f72dc6646a6785d86ba59c709bfe25c 100644
--- a/go/demo/mobilenet.go
+++ b/go/demo/mobilenet.go
@@ -13,7 +13,7 @@
// limitations under the License.
package main
-import "../paddle"
+import "github.com/paddlepaddle/paddle"
import "strings"
import "io/ioutil"
import "strconv"
diff --git a/go/demo/mobilenet_c_exp.cc b/go/demo/mobilenet_c_exp.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b4f42dab6790bfb6dd33860a8ada704166bb74ac
--- /dev/null
+++ b/go/demo/mobilenet_c_exp.cc
@@ -0,0 +1,84 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <pd_inference_api.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+void ReadData(float* data, int size);
+
+int main(int argc, char* argv[]) {
+ PD_Config* config = PD_ConfigCreate();
+ PD_ConfigSetModel(config, "data/model/__model__", "data/model/__params__");
+ PD_ConfigDisableGlogInfo(config);
+
+ PD_Predictor* predictor = PD_PredictorCreate(config);
+ // config has destroyed in PD_PredictorCreate
+ config = NULL;
+
+ int input_num = PD_PredictorGetInputNum(predictor);
+ printf("Input num: %d\n", input_num);
+ int output_num = PD_PredictorGetOutputNum(predictor);
+ printf("Output num: %d\n", output_num);
+
+ PD_OneDimArrayCstr* input_names = PD_PredictorGetInputNames(predictor);
+ PD_Tensor* input_tensor =
+ PD_PredictorGetInputHandle(predictor, input_names->data[0]);
+ PD_OneDimArrayCstrDestroy(input_names);
+ input_names = NULL;
+
+ int32_t shape[] = {1, 3, 300, 300};
+ float* data = (float*)malloc(sizeof(float) * 1 * 3 * 300 * 300); // NOLINT
+ ReadData(data, 1 * 3 * 300 * 300); // NOLINT
+ PD_TensorReshape(input_tensor, 4, shape);
+ PD_TensorCopyFromCpuFloat(input_tensor, data);
+ free(data);
+ data = NULL;
+ PD_PredictorRun(predictor);
+
+ PD_OneDimArrayCstr* output_names = PD_PredictorGetOutputNames(predictor);
+ PD_Tensor* output_tensor =
+ PD_PredictorGetOutputHandle(predictor, output_names->data[0]);
+ PD_OneDimArrayCstrDestroy(output_names);
+ output_names = nullptr;
+
+ PD_OneDimArrayInt32* out_shape = PD_TensorGetShape(output_tensor);
+ int32_t size = 1;
+ for (size_t index = 0; index < out_shape->size; ++index) {
+ size = size * out_shape->data[index];
+ }
+ PD_OneDimArrayInt32Destroy(out_shape);
+ out_shape = NULL;
+
+ data = (float*)malloc(sizeof(float) * size); // NOLINT
+ PD_TensorCopyToCpuFloat(output_tensor, data);
+ free(data);
+ data = NULL;
+
+ PD_TensorDestroy(output_tensor);
+ output_tensor = NULL;
+ PD_TensorDestroy(input_tensor);
+ input_tensor = NULL;
+ PD_PredictorDestroy(predictor);
+ predictor = NULL;
+
+ return 0;
+}
+
+void ReadData(float* data, int n) {
+ FILE* fp = fopen("data/data.txt", "r");
+ for (int i = 0; i < n; i++) {
+ fscanf(fp, "%f", &data[i]);
+ }
+ fclose(fp);
+}
diff --git a/go/paddle/common.go b/go/paddle/common.go
index 4bf947659312824216e6003cb2f150ae39a94d00..cbbde6a45f59b80931a3a2c501581819085e8ea7 100644
--- a/go/paddle/common.go
+++ b/go/paddle/common.go
@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
-// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
+// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <paddle_c_api.h>
import "C"
diff --git a/go/paddle/config.go b/go/paddle/config.go
index 89f7d7e63ff2a858f058ad22ea424b29f66a4477..68a31230997bed73fbab1c1d1c7af123e353cf97 100644
--- a/go/paddle/config.go
+++ b/go/paddle/config.go
@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
-// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
+// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <stdlib.h>
// #include <paddle_c_api.h>
diff --git a/go/paddle/predictor.go b/go/paddle/predictor.go
index 59bad908e6a5082e38b8bb33c849aa1097107d76..5f2b2c81a60549dfdbf22dd31a98560e7e3a8cee 100644
--- a/go/paddle/predictor.go
+++ b/go/paddle/predictor.go
@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
-// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
+// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include "paddle_c_api.h"
import "C"
@@ -88,7 +88,7 @@ func (predictor *Predictor) GetInputNames() []string {
}
func (predictor *Predictor) GetOutputNames() []string {
- names := make([]string, predictor.GetInputNum())
+ names := make([]string, predictor.GetOutputNum())
for i := 0; i < len(names); i++ {
names[i] = predictor.GetOutputName(i)
}
diff --git a/go/paddle/tensor.go b/go/paddle/tensor.go
index e6e2c53fef1af565d4efba976d10839efe22517d..6fbcf039f88a7cc43a5d28f0433c9feb965566f0 100644
--- a/go/paddle/tensor.go
+++ b/go/paddle/tensor.go
@@ -15,7 +15,7 @@
package paddle
// #cgo CFLAGS: -I${SRCDIR}/../paddle_c/paddle/include
-// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_fluid_c
+// #cgo LDFLAGS: -L${SRCDIR}/../paddle_c/paddle/lib -lpaddle_inference_c
// #include <stdbool.h>
// #include <stdlib.h>
// #include <paddle_c_api.h>
@@ -209,7 +209,7 @@ func DecodeTensor(r *bytes.Reader, shape []int32, t reflect.Type, ptr reflect.Va
value := reflect.Indirect(ptr)
value.Set(reflect.MakeSlice(t, int(shape[0]), int(shape[0])))
if len(shape) == 1 && value.Len() > 0 {
- switch value.Index(1).Kind() {
+ switch value.Index(0).Kind() {
case reflect.Uint8, reflect.Int32, reflect.Int64, reflect.Float32:
binary.Read(r, Endian(), value.Interface())
return
diff --git a/paddle/extension.h b/paddle/extension.h
index 71469576853a33b9158713304a68c6ac757aab4f..98d4bfd0326c5c524fcac9129f58d0ae99fc8afe 100644
--- a/paddle/extension.h
+++ b/paddle/extension.h
@@ -15,4 +15,4 @@ limitations under the License. */
#pragma once
// All paddle apis in C++ frontend
-#include "paddle/fluid/extension/include/ext_all.h"
+#include "paddle/extension/include/ext_all.h"
diff --git a/paddle/fluid/CMakeLists.txt b/paddle/fluid/CMakeLists.txt
index c18332d3b873164a725a25316fc611aa7e7a3092..dcff02a662e2734bc66d4cf219fce527fd0961aa 100644
--- a/paddle/fluid/CMakeLists.txt
+++ b/paddle/fluid/CMakeLists.txt
@@ -9,4 +9,3 @@ add_subdirectory(pybind)
# NOTE: please add subdirectory inference at last.
add_subdirectory(inference)
-add_subdirectory(train)
diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt
index 5a2d7a06201ba4acff679ffcfee87fde8d025ed6..905347d031b35b39b43879c7bd78ab39e933a5b3 100644
--- a/paddle/fluid/distributed/CMakeLists.txt
+++ b/paddle/fluid/distributed/CMakeLists.txt
@@ -11,9 +11,10 @@ if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0)
"${DISTRIBUTE_COMPILE_FLAGS} -faligned-new")
endif()
-add_subdirectory(table)
add_subdirectory(service)
+add_subdirectory(table)
add_subdirectory(test)
+add_subdirectory(index_dataset)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
diff --git a/paddle/fluid/distributed/fleet.cc b/paddle/fluid/distributed/fleet.cc
index b638af49730dd4800109729c9d91afa82efa80e4..dfd55f16e1a065e46b2186a6a589eabc1ac3b431 100644
--- a/paddle/fluid/distributed/fleet.cc
+++ b/paddle/fluid/distributed/fleet.cc
@@ -177,8 +177,11 @@ std::future<int32_t> FleetWrapper::PullSparseVarsAsync(
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
- return pserver_ptr_->_worker_ptr->pull_sparse(
- pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
+
+ bool training = true;
+ return pserver_ptr_->_worker_ptr->pull_sparse(pull_result_ptr.data(),
+ table_id, fea_keys->data(),
+ fea_keys->size(), training);
}
void FleetWrapper::PullSparseVarsSync(
@@ -224,8 +227,10 @@ void FleetWrapper::PullSparseVarsSync(
for (auto& t : *fea_values) {
pull_result_ptr.push_back(t.data());
}
+ bool training = true;
auto status = pserver_ptr_->_worker_ptr->pull_sparse(
- pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size());
+ pull_result_ptr.data(), table_id, fea_keys->data(), fea_keys->size(),
+ training);
pull_sparse_status.push_back(std::move(status));
for (auto& t : pull_sparse_status) {
t.wait();
@@ -238,9 +243,13 @@ void FleetWrapper::PullSparseVarsSync(
}
}
+// is_training == true means training, false means inference; the pserver
+// handles the pulled keys differently in the two modes
+
void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id,
platform::Place place,
+ bool is_training,
                                          std::vector<const LoDTensor*>* inputs,
                                          std::vector<LoDTensor*>* outputs) {
  std::vector<uint64_t> fea_keys;
@@ -279,7 +288,8 @@ void FleetWrapper::PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
}
auto* communicator = Communicator::GetInstance();
auto status = communicator->_worker_ptr->pull_sparse(
- pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size());
+ pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size(),
+ is_training);
status.wait();
auto ret = status.get();
if (ret != 0) {
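
A minimal sketch of a call site for the new `is_training` parameter; the wrapper, place and tensors here are assumed to already exist, and the names are illustrative rather than part of this patch:

```cpp
#include <vector>
#include "paddle/fluid/distributed/fleet.h"

// Pull embeddings for inference: is_training=false tells the pserver to
// skip its training-side handling of the pulled keys.
void PullForInference(paddle::distributed::FleetWrapper* fleet,
                      const paddle::platform::Place& place,
                      const paddle::framework::LoDTensor* ids,
                      paddle::framework::LoDTensor* emb) {
  std::vector<const paddle::framework::LoDTensor*> inputs = {ids};
  std::vector<paddle::framework::LoDTensor*> outputs = {emb};
  fleet->PullSparseToTensorSync(/*table_id=*/0, /*fea_dim=*/64,
                                /*padding_id=*/0, place,
                                /*is_training=*/false, &inputs, &outputs);
}
```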
diff --git a/paddle/fluid/distributed/fleet.h b/paddle/fluid/distributed/fleet.h
index ac566606ddcb4024eeaf7b846c894f7f5cdafa82..0da5d1e2bf987f38de3b9a03c659fc5e1841eca1 100644
--- a/paddle/fluid/distributed/fleet.h
+++ b/paddle/fluid/distributed/fleet.h
@@ -95,8 +95,12 @@ class FleetWrapper {
// Pull sparse variables from server in sync mode
// pull immediately to tensors
+  // is_training == true means training, false means inference; the pserver
+  // handles the pulled keys differently in the two modes
+
void PullSparseToTensorSync(const uint64_t table_id, int fea_dim,
uint64_t padding_id, platform::Place place,
+ bool is_training,
                              std::vector<const LoDTensor*>* inputs,  // NOLINT
                              std::vector<LoDTensor*>* outputs);  // NOLINT
diff --git a/paddle/fluid/distributed/index_dataset/CMakeLists.txt b/paddle/fluid/distributed/index_dataset/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a30488494a52bcfea61476caeb1ab08e3e6781a1
--- /dev/null
+++ b/paddle/fluid/distributed/index_dataset/CMakeLists.txt
@@ -0,0 +1,7 @@
+proto_library(index_dataset_proto SRCS index_dataset.proto)
+cc_library(index_wrapper SRCS index_wrapper.cc DEPS index_dataset_proto fs)
+cc_library(index_sampler SRCS index_sampler.cc DEPS index_wrapper)
+
+if(WITH_PYTHON)
+ py_proto_compile(index_dataset_py_proto SRCS index_dataset.proto)
+endif()
diff --git a/paddle/fluid/distributed/index_dataset/index_dataset.proto b/paddle/fluid/distributed/index_dataset/index_dataset.proto
new file mode 100644
index 0000000000000000000000000000000000000000..1b4ee313671ad503b9e46dbe9e34d4a69d0cfc4d
--- /dev/null
+++ b/paddle/fluid/distributed/index_dataset/index_dataset.proto
@@ -0,0 +1,32 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto2";
+package paddle.distributed;
+
+message IndexNode {
+ required uint64 id = 1;
+ required bool is_leaf = 2;
+ required float probability = 3;
+}
+
+message TreeMeta {
+ required int32 height = 1;
+ required int32 branch = 2;
+}
+
+message KVItem {
+ required bytes key = 1;
+ required bytes value = 2;
+}
\ No newline at end of file
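
The file that `TreeIndex::Load` (further down in this patch) consumes is a flat sequence of length-prefixed `KVItem` records: a 4-byte record length followed by the serialized message, with the tree metadata stored under the reserved key `.tree_meta`. A minimal writer sketch under that assumption:

```cpp
#include <cstdio>
#include <string>
#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"

// Append one KVItem as <int32 length><serialized bytes>, the framing
// that TreeIndex::Load reads back with fread.
void WriteKVItem(FILE* fp, const std::string& key, const std::string& value) {
  paddle::distributed::KVItem item;
  item.set_key(key);
  item.set_value(value);
  std::string bytes;
  item.SerializeToString(&bytes);
  int num = static_cast<int>(bytes.size());
  fwrite(&num, sizeof(num), 1, fp);
  fwrite(bytes.data(), 1, bytes.size(), fp);
}
```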
diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.cc b/paddle/fluid/distributed/index_dataset/index_sampler.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3e573bbdd2de97130a109ddb583a724cf363c6be
--- /dev/null
+++ b/paddle/fluid/distributed/index_dataset/index_sampler.cc
@@ -0,0 +1,74 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/index_dataset/index_sampler.h"
+
+namespace paddle {
+namespace distributed {
+
+std::vector<std::vector<uint64_t>> LayerWiseSampler::sample(
+    const std::vector<std::vector<uint64_t>>& user_inputs,
+    const std::vector<uint64_t>& target_ids, bool with_hierarchy) {
+  auto input_num = target_ids.size();
+  auto user_feature_num = user_inputs[0].size();
+  std::vector<std::vector<uint64_t>> outputs(
+      input_num * layer_counts_sum_,
+      std::vector<uint64_t>(user_feature_num + 2));
+
+ auto max_layer = tree_->Height();
+ size_t idx = 0;
+ for (size_t i = 0; i < input_num; i++) {
+ auto travel_codes =
+ tree_->GetTravelCodes(target_ids[i], start_sample_layer_);
+ auto travel_path = tree_->GetNodes(travel_codes);
+ for (size_t j = 0; j < travel_path.size(); j++) {
+ // user
+ if (j > 0 && with_hierarchy) {
+ auto ancestor_codes =
+ tree_->GetAncestorCodes(user_inputs[i], max_layer - j - 1);
+ auto hierarchical_user = tree_->GetNodes(ancestor_codes);
+ for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
+ for (size_t k = 0; k < user_feature_num; k++) {
+ outputs[idx + idx_offset][k] = hierarchical_user[k].id();
+ }
+ }
+ } else {
+ for (int idx_offset = 0; idx_offset <= layer_counts_[j]; idx_offset++) {
+ for (size_t k = 0; k < user_feature_num; k++) {
+ outputs[idx + idx_offset][k] = user_inputs[i][k];
+ }
+ }
+ }
+
+ // sampler ++
+ outputs[idx][user_feature_num] = travel_path[j].id();
+ outputs[idx][user_feature_num + 1] = 1.0;
+ idx += 1;
+ for (int idx_offset = 0; idx_offset < layer_counts_[j]; idx_offset++) {
+ int sample_res = 0;
+ do {
+ sample_res = sampler_vec_[j]->Sample();
+ } while (layer_ids_[j][sample_res].id() == travel_path[j].id());
+ outputs[idx + idx_offset][user_feature_num] =
+ layer_ids_[j][sample_res].id();
+ outputs[idx + idx_offset][user_feature_num + 1] = 0;
+ }
+ idx += layer_counts_[j];
+ }
+ }
+ return outputs;
+}
+
+} // end namespace distributed
+} // end namespace paddle
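
A usage sketch for the sampler, assuming a tree has already been registered under the (hypothetical) name "demo_tree" and the reconstructed signatures above:

```cpp
#include <cstdint>
#include <vector>
#include "paddle/fluid/distributed/index_dataset/index_sampler.h"

using paddle::distributed::IndexSampler;
using paddle::distributed::LayerWiseSampler;

void SampleDemo(const std::vector<std::vector<uint64_t>>& user_inputs,
                const std::vector<uint64_t>& targets) {
  auto sampler = IndexSampler::Init<LayerWiseSampler>("demo_tree");
  // Two negatives per sampled layer (plus the positive travel-path node),
  // starting from layer 1 with seed 0.
  sampler->init_layerwise_conf({2, 2, 2}, /*start_sample_layer=*/1, /*seed=*/0);
  auto rows = sampler->sample(user_inputs, targets, /*with_hierarchy=*/false);
  // Each row holds: user features..., candidate node id, label (1=pos, 0=neg).
}
```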
diff --git a/paddle/fluid/distributed/index_dataset/index_sampler.h b/paddle/fluid/distributed/index_dataset/index_sampler.h
new file mode 100644
index 0000000000000000000000000000000000000000..8813421446a21c1379ca872952fe8b367d0724ca
--- /dev/null
+++ b/paddle/fluid/distributed/index_dataset/index_sampler.h
@@ -0,0 +1,120 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/math/sampler.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace distributed {
+
+class IndexSampler {
+ public:
+ virtual ~IndexSampler() {}
+ IndexSampler() {}
+
+  template <typename T>
+  static std::shared_ptr<IndexSampler> Init(const std::string& name) {
+    std::shared_ptr<IndexSampler> instance = nullptr;
+ instance.reset(new T(name));
+ return instance;
+ }
+
+  virtual void init_layerwise_conf(const std::vector<int>& layer_sample_counts,
+ int start_sample_layer = 1, int seed = 0) {}
+ virtual void init_beamsearch_conf(const int64_t k) {}
+  virtual std::vector<std::vector<uint64_t>> sample(
+      const std::vector<std::vector<uint64_t>>& user_inputs,
+      const std::vector<uint64_t>& input_targets,
+      bool with_hierarchy = false) = 0;
+};
+
+class LayerWiseSampler : public IndexSampler {
+ public:
+ virtual ~LayerWiseSampler() {}
+ explicit LayerWiseSampler(const std::string& name) {
+ tree_ = IndexWrapper::GetInstance()->get_tree_index(name);
+ }
+
+  void init_layerwise_conf(const std::vector<int>& layer_sample_counts,
+ int start_sample_layer, int seed) override {
+ seed_ = seed;
+ start_sample_layer_ = start_sample_layer;
+
+ PADDLE_ENFORCE_GT(
+ start_sample_layer_, 0,
+ paddle::platform::errors::InvalidArgument(
+ "start sampler layer = [%d], it should greater than 0.",
+ start_sample_layer_));
+ PADDLE_ENFORCE_LT(start_sample_layer_, tree_->Height(),
+ paddle::platform::errors::InvalidArgument(
+ "start sampler layer = [%d], it should less than "
+ "max_layer, which is [%d].",
+ start_sample_layer_, tree_->Height()));
+
+ size_t i = 0;
+ layer_counts_sum_ = 0;
+ layer_counts_.clear();
+ int cur_layer = start_sample_layer_;
+ while (cur_layer < tree_->Height()) {
+ int layer_sample_num = 1;
+ if (i < layer_sample_counts.size()) {
+ layer_sample_num = layer_sample_counts[i];
+ }
+ layer_counts_sum_ += layer_sample_num + 1;
+ layer_counts_.push_back(layer_sample_num);
+ VLOG(3) << "[INFO] level " << cur_layer
+ << " sample_layer_counts.push_back: " << layer_sample_num;
+ cur_layer += 1;
+ i += 1;
+ }
+ reverse(layer_counts_.begin(), layer_counts_.end());
+ VLOG(3) << "sample counts sum: " << layer_counts_sum_;
+
+ auto max_layer = tree_->Height();
+ sampler_vec_.clear();
+ layer_ids_.clear();
+
+ auto layer_index = max_layer - 1;
+ size_t idx = 0;
+ while (layer_index >= start_sample_layer_) {
+ auto layer_codes = tree_->GetLayerCodes(layer_index);
+ layer_ids_.push_back(tree_->GetNodes(layer_codes));
+      auto sampler_temp =
+          std::make_shared<paddle::operators::math::UniformSampler>(
+              layer_ids_[idx].size() - 1, seed_);
+ sampler_vec_.push_back(sampler_temp);
+ layer_index--;
+ idx++;
+ }
+ }
+  std::vector<std::vector<uint64_t>> sample(
+      const std::vector<std::vector<uint64_t>>& user_inputs,
+      const std::vector<uint64_t>& target_ids, bool with_hierarchy) override;
+
+ private:
+  std::vector<int> layer_counts_;
+  int64_t layer_counts_sum_{0};
+  std::shared_ptr<TreeIndex> tree_{nullptr};
+  int seed_{0};
+  int start_sample_layer_{1};
+  std::vector<std::shared_ptr<paddle::operators::math::Sampler>> sampler_vec_;
+  std::vector<std::vector<IndexNode>> layer_ids_;
+};
+
+} // end namespace distributed
+} // end namespace paddle
diff --git a/paddle/fluid/distributed/index_dataset/index_wrapper.cc b/paddle/fluid/distributed/index_dataset/index_wrapper.cc
new file mode 100644
index 0000000000000000000000000000000000000000..99fe4ca0c6d043caef01a867a5acc0d40841ee01
--- /dev/null
+++ b/paddle/fluid/distributed/index_dataset/index_wrapper.cc
@@ -0,0 +1,196 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cmath>
+#include <cstdio>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+#include "paddle/fluid/framework/io/fs.h"
+
+#include <boost/lexical_cast.hpp>
+#include <functional>
+#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"
+
+namespace paddle {
+namespace distributed {
+
+std::shared_ptr<IndexWrapper> IndexWrapper::s_instance_(nullptr);
+
+int TreeIndex::Load(const std::string filename) {
+ int err_no;
+ auto fp = paddle::framework::fs_open_read(filename, &err_no, "");
+ PADDLE_ENFORCE_NE(
+ fp, nullptr,
+ platform::errors::InvalidArgument(
+ "Open file %s failed. Please check whether the file exists.",
+ filename));
+
+ int num = 0;
+ max_id_ = 0;
+ fake_node_.set_id(0);
+ fake_node_.set_is_leaf(false);
+ fake_node_.set_probability(0.0);
+ max_code_ = 0;
+ size_t ret = fread(&num, sizeof(num), 1, fp.get());
+ while (ret == 1 && num > 0) {
+ std::string content(num, '\0');
+ size_t read_num =
+        fread(const_cast<char *>(content.data()), 1, num, fp.get());
+ PADDLE_ENFORCE_EQ(
+ read_num, static_cast(num),
+ platform::errors::InvalidArgument(
+ "Read from file: %s failed. Valid Format is "
+ "an integer representing the length of the following string, "
+ "and the string itself.We got an iteger[% d], "
+ "but the following string's length is [%d].",
+ filename, num, read_num));
+
+ KVItem item;
+ PADDLE_ENFORCE_EQ(
+ item.ParseFromString(content), true,
+ platform::errors::InvalidArgument("Parse from file: %s failed. It's "
+ "content can't be parsed by KVItem.",
+ filename));
+
+ if (item.key() == ".tree_meta") {
+ meta_.ParseFromString(item.value());
+ } else {
+      auto code = boost::lexical_cast<uint64_t>(item.key());
+ IndexNode node;
+ node.ParseFromString(item.value());
+ PADDLE_ENFORCE_NE(node.id(), 0,
+ platform::errors::InvalidArgument(
+ "Node'id should not be equel to zero."));
+ if (node.is_leaf()) {
+ id_codes_map_[node.id()] = code;
+ }
+ data_[code] = node;
+ if (node.id() > max_id_) {
+ max_id_ = node.id();
+ }
+ if (code > max_code_) {
+ max_code_ = code;
+ }
+ }
+ ret = fread(&num, sizeof(num), 1, fp.get());
+ }
+ total_nodes_num_ = data_.size();
+ max_code_ += 1;
+ return 0;
+}
+
+std::vector<IndexNode> TreeIndex::GetNodes(const std::vector<uint64_t>& codes) {
+  std::vector<IndexNode> nodes;
+ nodes.reserve(codes.size());
+ for (size_t i = 0; i < codes.size(); i++) {
+ if (CheckIsValid(codes[i])) {
+ nodes.push_back(data_.at(codes[i]));
+ } else {
+ nodes.push_back(fake_node_);
+ }
+ }
+ return nodes;
+}
+
+std::vector<uint64_t> TreeIndex::GetLayerCodes(int level) {
+  uint64_t level_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
+ uint64_t level_offset = level_num - 1;
+
+  std::vector<uint64_t> res;
+ res.reserve(level_num);
+ for (uint64_t i = 0; i < level_num; i++) {
+ auto code = level_offset + i;
+ if (CheckIsValid(code)) {
+ res.push_back(code);
+ }
+ }
+ return res;
+}
+
+std::vector<uint64_t> TreeIndex::GetAncestorCodes(
+    const std::vector<uint64_t>& ids, int level) {
+  std::vector<uint64_t> res;
+ res.reserve(ids.size());
+
+ int cur_level;
+ for (size_t i = 0; i < ids.size(); i++) {
+ if (id_codes_map_.find(ids[i]) == id_codes_map_.end()) {
+ res.push_back(max_code_);
+ } else {
+ auto code = id_codes_map_.at(ids[i]);
+ cur_level = meta_.height() - 1;
+
+ while (level >= 0 && cur_level > level) {
+ code = (code - 1) / meta_.branch();
+ cur_level--;
+ }
+ res.push_back(code);
+ }
+ }
+ return res;
+}
+
+std::vector<uint64_t> TreeIndex::GetChildrenCodes(uint64_t ancestor,
+                                                  int level) {
+  auto level_code_num = static_cast<uint64_t>(std::pow(meta_.branch(), level));
+ auto code_min = level_code_num - 1;
+ auto code_max = meta_.branch() * level_code_num - 1;
+
+  std::vector<uint64_t> parent;
+ parent.push_back(ancestor);
+  std::vector<uint64_t> res;
+ size_t p_idx = 0;
+ while (true) {
+ size_t p_size = parent.size();
+ for (; p_idx < p_size; p_idx++) {
+ for (int i = 0; i < meta_.branch(); i++) {
+ auto code = parent[p_idx] * meta_.branch() + i + 1;
+ if (data_.find(code) != data_.end()) parent.push_back(code);
+ }
+ }
+ if ((code_min <= parent[p_idx]) && (parent[p_idx] < code_max)) {
+ break;
+ }
+ }
+
+  return std::vector<uint64_t>(parent.begin() + p_idx, parent.end());
+}
+
+std::vector<uint64_t> TreeIndex::GetTravelCodes(uint64_t id, int start_level) {
+  std::vector<uint64_t> res;
+ PADDLE_ENFORCE_NE(id_codes_map_.find(id), id_codes_map_.end(),
+ paddle::platform::errors::InvalidArgument(
+ "id = %d doesn't exist in Tree.", id));
+ auto code = id_codes_map_.at(id);
+ int level = meta_.height() - 1;
+
+ while (level >= start_level) {
+ res.push_back(code);
+ code = (code - 1) / meta_.branch();
+ level--;
+ }
+ return res;
+}
+
+std::vector<IndexNode> TreeIndex::GetAllLeafs() {
+  std::vector<IndexNode> res;
+ res.reserve(id_codes_map_.size());
+ for (auto& ite : id_codes_map_) {
+ auto code = ite.second;
+ res.push_back(data_.at(code));
+ }
+ return res;
+}
+
+} // end namespace distributed
+} // end namespace paddle
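
All of the code arithmetic above relies on the implicit array layout of a complete `branch`-ary tree; a tiny sketch of the invariants that `GetChildrenCodes` and `GetTravelCodes` depend on:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t b = 3;     // branch factor from TreeMeta (any b >= 2)
  const uint64_t code = 7;  // an arbitrary node code; the root is code 0
  // Children of code c are c*b+1 .. c*b+b, so stepping back with
  // (child - 1) / b recovers c; this is exactly the parent step used in
  // GetTravelCodes and GetAncestorCodes.
  for (uint64_t i = 1; i <= b; ++i) {
    assert((code * b + i - 1) / b == code);
  }
  return 0;
}
```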
diff --git a/paddle/fluid/distributed/index_dataset/index_wrapper.h b/paddle/fluid/distributed/index_dataset/index_wrapper.h
new file mode 100644
index 0000000000000000000000000000000000000000..8fb8faf6c84a2d9e1a5e80179a113b8d1ef312c8
--- /dev/null
+++ b/paddle/fluid/distributed/index_dataset/index_wrapper.h
@@ -0,0 +1,120 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+#include "paddle/fluid/distributed/index_dataset/index_dataset.pb.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace distributed {
+
+class Index {
+ public:
+ Index() {}
+ ~Index() {}
+};
+
+class TreeIndex : public Index {
+ public:
+ TreeIndex() {}
+ ~TreeIndex() {}
+
+ int Height() { return meta_.height(); }
+ int Branch() { return meta_.branch(); }
+ uint64_t TotalNodeNums() { return total_nodes_num_; }
+ uint64_t EmbSize() { return max_id_ + 1; }
+ int Load(const std::string path);
+
+ inline bool CheckIsValid(int code) {
+ if (data_.find(code) != data_.end()) {
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+  std::vector<IndexNode> GetNodes(const std::vector<uint64_t>& codes);
+  std::vector<uint64_t> GetLayerCodes(int level);
+  std::vector<uint64_t> GetAncestorCodes(const std::vector<uint64_t>& ids,
+                                         int level);
+  std::vector<uint64_t> GetChildrenCodes(uint64_t ancestor, int level);
+  std::vector<uint64_t> GetTravelCodes(uint64_t id, int start_level);
+  std::vector<IndexNode> GetAllLeafs();
+
+  std::unordered_map<uint64_t, IndexNode> data_;
+  std::unordered_map<uint64_t, uint64_t> id_codes_map_;
+ uint64_t total_nodes_num_;
+ TreeMeta meta_;
+ uint64_t max_id_;
+ uint64_t max_code_;
+ IndexNode fake_node_;
+};
+
+using TreePtr = std::shared_ptr<TreeIndex>;
+
+class IndexWrapper {
+ public:
+ virtual ~IndexWrapper() {}
+ IndexWrapper() {}
+
+ void clear_tree() { tree_map.clear(); }
+
+ TreePtr get_tree_index(const std::string name) {
+ PADDLE_ENFORCE_NE(tree_map.find(name), tree_map.end(),
+ paddle::platform::errors::InvalidArgument(
+ "tree [%s] doesn't exist. Please insert it firstly "
+ "by API[\' insert_tree_index \'].",
+ name));
+ return tree_map[name];
+ }
+
+ void insert_tree_index(const std::string name, const std::string tree_path) {
+ if (tree_map.find(name) != tree_map.end()) {
+ VLOG(0) << "Tree " << name << " has already existed.";
+ return;
+ }
+    TreePtr tree = std::make_shared<TreeIndex>();
+ int ret = tree->Load(tree_path);
+ PADDLE_ENFORCE_EQ(ret, 0, paddle::platform::errors::InvalidArgument(
+ "Load tree[%s] from path[%s] failed. Please "
+ "check whether the file exists.",
+ name, tree_path));
+    tree_map.insert(std::pair<std::string, TreePtr>{name, tree});
+ }
+
+  static std::shared_ptr<IndexWrapper> GetInstancePtr() {
+ if (NULL == s_instance_) {
+ s_instance_.reset(new paddle::distributed::IndexWrapper());
+ }
+ return s_instance_;
+ }
+
+ static IndexWrapper* GetInstance() {
+ if (NULL == s_instance_) {
+ s_instance_.reset(new paddle::distributed::IndexWrapper());
+ }
+ return s_instance_.get();
+ }
+
+ private:
+  static std::shared_ptr<IndexWrapper> s_instance_;
+  std::unordered_map<std::string, TreePtr> tree_map;
+};
+
+} // end namespace distributed
+} // end namespace paddle
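
Typical use of the singleton, with an illustrative tree name and path (both hypothetical):

```cpp
#include <string>
#include "paddle/fluid/distributed/index_dataset/index_wrapper.h"

void LoadDemoTree() {
  auto* wrapper = paddle::distributed::IndexWrapper::GetInstance();
  // Registers and loads the serialized tree once per process; repeated
  // inserts under the same name are ignored with a log message.
  wrapper->insert_tree_index("demo_tree", "/path/to/demo_tree.dat");
  paddle::distributed::TreePtr tree = wrapper->get_tree_index("demo_tree");
  int height = tree->Height();             // from TreeMeta
  uint64_t nodes = tree->TotalNodeNums();  // nodes actually loaded
  (void)height;
  (void)nodes;
}
```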
diff --git a/paddle/fluid/distributed/service/CMakeLists.txt b/paddle/fluid/distributed/service/CMakeLists.txt
index bb3f6f1174da9d49a8407ec8db16a5a2aa2a8336..d1f04e26ade7289bcb10988d02de01962a1889ab 100644
--- a/paddle/fluid/distributed/service/CMakeLists.txt
+++ b/paddle/fluid/distributed/service/CMakeLists.txt
@@ -16,6 +16,7 @@ set_source_files_properties(communicator.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUT
set_source_files_properties(service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+set_source_files_properties(ps_local_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(brpc_utils.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(heter_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -24,11 +25,13 @@ set_source_files_properties(heter_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUT
set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-
+set_source_files_properties(graph_brpc_server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+set_source_files_properties(graph_brpc_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
-cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS})
-cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS})
+cc_library(downpour_server SRCS graph_brpc_server.cc brpc_ps_server.cc DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
+cc_library(downpour_client SRCS graph_brpc_client.cc brpc_ps_client.cc
+ps_local_client.cc DEPS boost eigen3 table brpc_utils simple_threadpool ${RPC_DEPS})
cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
@@ -38,3 +41,6 @@ cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RP
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
+
+set_source_files_properties(graph_py_service.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
+cc_library(graph_py_service SRCS graph_py_service.cc DEPS ps_service)
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc
index 163526fe3b28c91f36e2670d1974b520ef3bf66a..a6ad9d08f52fda9bd79b1a1f0eebf1769c855eb3 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_client.cc
@@ -768,8 +768,8 @@ std::future<int32_t> BrpcPsClient::push_global_step(int table_id,
std::future<int32_t> BrpcPsClient::pull_sparse(float **select_values,
size_t table_id,
- const uint64_t *keys,
- size_t num) {
+ const uint64_t *keys, size_t num,
+ bool is_training) {
size_t request_call_num = _server_channels.size();
auto shard_sorted_kvs = std::make_shared<
@@ -837,16 +837,27 @@ std::future BrpcPsClient::pull_sparse(float **select_values,
uint32_t kv_request_count = 0;
size_t sorted_kv_size = sorted_kvs.size();
auto &request_buffer = closure->cntl(i)->request_attachment();
+
+ request_buffer.append((void *)&is_training, sizeof(bool));
+  std::vector<uint32_t> keys_counter;
+ keys_counter.reserve(sorted_kv_size);
+
for (size_t kv_idx = 0; kv_idx < sorted_kv_size; ++kv_idx) {
++kv_request_count;
+ uint32_t keys = 1;
last_key = sorted_kvs[kv_idx].first;
request_buffer.append((void *)&last_key, sizeof(uint64_t));
while (kv_idx < sorted_kv_size - 1 &&
last_key == sorted_kvs[kv_idx + 1].first) {
++kv_idx;
+ ++keys;
}
+ keys_counter.push_back(keys);
}
+ request_buffer.append((void *)keys_counter.data(),
+ sizeof(uint32_t) * keys_counter.size());
+
if (kv_request_count == 0) {
closure->Run();
} else {
@@ -869,8 +880,8 @@ std::future BrpcPsClient::send_client2client_msg(
  auto promise = std::make_shared<std::promise<int32_t>>();
  std::future<int32_t> fut = promise->get_future();
if (to_client_id >= _client_channels.size()) {
- LOG(FATAL) << "to_client_id is out of range clients, which size is "
- << _client_channels.size();
+ VLOG(0) << "to_client_id is out of range clients, which size is "
+ << _client_channels.size();
promise->set_value(-1);
return fut;
}
@@ -956,7 +967,7 @@ int32_t BrpcPsClient::recv_and_save_table(const uint64_t table_id,
}
auto status = pull_sparse((float **)save_vec.data(), table_id,
- save_key.data(), save_key.size());
+ save_key.data(), save_key.size(), true);
status.wait();
// create lod tensor
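
With this change the client dedupes keys before shipping them; the per-server request attachment layout implied by the append calls above (and parsed by the server below) is, as a sketch:

```cpp
#include <cstdint>
#include <vector>

// Logical layout of the pull_sparse request attachment per server shard:
//   | is_training (bool)            |
//   | unique sorted keys (uint64_t) |
//   | per-key counters (uint32_t)   |
// keys_counter[i] records how many times keys[i] occurred in the original
// request, so the server can fan its answer back out to every duplicate.
struct PullSparseRequestLayout {
  bool is_training;
  std::vector<uint64_t> unique_keys;
  std::vector<uint32_t> keys_counter;  // same length as unique_keys
};
```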
diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h
index 8f9d2653864d1c7fd1801632a6c84edb1bc04ccf..5192356e4b5e574de385478c57a7b7cedb49988a 100644
--- a/paddle/fluid/distributed/service/brpc_ps_client.h
+++ b/paddle/fluid/distributed/service/brpc_ps_client.h
@@ -148,7 +148,8 @@ class BrpcPsClient : public PSClient {
  virtual std::future<int32_t> pull_sparse(float **select_values,
size_t table_id,
- const uint64_t *keys, size_t num);
+ const uint64_t *keys, size_t num,
+ bool is_training);
  virtual std::future<int32_t> print_table_stat(uint32_t table_id);
@@ -170,9 +171,22 @@ class BrpcPsClient : public PSClient {
virtual int32_t recv_and_save_table(const uint64_t table_id,
const std::string &path);
- private:
+ protected:
+ virtual size_t get_server_nums() { return _server_channels.size(); }
+ inline brpc::Channel *get_sparse_channel(size_t server_id) {
+ return _server_channels[server_id][0].get();
+ }
+ inline brpc::Channel *get_dense_channel(size_t server_id) {
+ return _server_channels[server_id][1].get();
+ }
+ inline brpc::Channel *get_cmd_channel(size_t server_id) {
+ return _server_channels[server_id][2].get();
+ }
virtual int32_t initialize() override;
+ private:
+ // virtual int32_t initialize() override;
+
inline uint32_t dense_dim_per_shard(uint32_t dense_dim_total,
uint32_t shard_num) {
return dense_dim_total / shard_num + 1;
@@ -184,16 +198,6 @@ class BrpcPsClient : public PSClient {
  std::future<int32_t> send_save_cmd(uint32_t table_id, int cmd_id,
                                     const std::vector<std::string> &param);
- inline brpc::Channel *get_sparse_channel(size_t server_id) {
- return _server_channels[server_id][0].get();
- }
- inline brpc::Channel *get_dense_channel(size_t server_id) {
- return _server_channels[server_id][1].get();
- }
- inline brpc::Channel *get_cmd_channel(size_t server_id) {
- return _server_channels[server_id][2].get();
- }
-
bool _running = false;
bool _flushing = false;
  std::atomic<uint32_t> _async_call_num;  // async request counter
@@ -220,8 +224,6 @@ class BrpcPsClient : public PSClient {
size_t num,
void *done) override;
- virtual size_t get_server_nums() { return _server_channels.size(); }
-
private:
int32_t start_client_service();
diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc
index 8400e669182d670b892dc2eb55492a92ee919ae5..a1440260bf2e77093bb937e62b13b54ad06a3e64 100644
--- a/paddle/fluid/distributed/service/brpc_ps_server.cc
+++ b/paddle/fluid/distributed/service/brpc_ps_server.cc
@@ -14,6 +14,8 @@
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread>  // NOLINT
+#include "butil/object_pool.h"
+#include "paddle/fluid/distributed/table/depends/sparse_utils.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
#include "paddle/fluid/platform/profiler.h"
@@ -60,7 +62,8 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
    std::unique_lock<std::mutex> lock(mutex_);
std::string ip_port = ip + ":" + std::to_string(port);
- VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
+ VLOG(0) << "running server with rank id: " << _rank
+ << ", endpoint: " << ip_port;
brpc::ServerOptions options;
int num_threads = std::thread::hardware_concurrency();
@@ -194,12 +197,13 @@ int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
return 0;
}
-  std::vector<float> res_data;
-  res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
-  table->pull_dense(res_data.data(), num);
+  auto res_data = butil::get_object<std::vector<float>>();
+  res_data->resize(num * table->value_accesor()->select_size() / sizeof(float));
+  table->pull_dense(res_data->data(), num);
- cntl->response_attachment().append((char *)res_data.data(),
- res_data.size() * sizeof(float));
+ cntl->response_attachment().append((char *)(res_data->data()),
+ res_data->size() * sizeof(float));
+ butil::return_object(res_data);
return 0;
}
@@ -336,35 +340,42 @@ int32_t BrpcPsService::pull_sparse(Table *table,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_sparse");
CHECK_TABLE_EXIST(table, request, response)
- thread_local std::string push_sparse_request_buffer;
+
auto &req_io_buffer = cntl->request_attachment();
auto req_buffer_size = req_io_buffer.size();
+
if (req_buffer_size < 1) {
set_response_code(response, -1, "req attachment is empty");
return 0;
}
+
if (request.params_size() < 1) {
set_response_code(response, -1,
"PsRequestMessage.params is requeired at "
"least 1 for num of sparse_key");
return 0;
}
+
uint32_t num = *(uint32_t *)(request.params(0).c_str());
- push_sparse_request_buffer.resize(0);
- push_sparse_request_buffer.reserve(req_buffer_size);
- const char *data = (const char *)cntl->request_attachment().fetch(
- const_cast(push_sparse_request_buffer.data()), req_buffer_size);
- /*
- Attachment Content:
- |---keysData---|
- |---8*{num}B---|
- */
- const uint64_t *keys = (const uint64_t *)data;
-  std::vector<float> res_data;
-  res_data.resize(num * table->value_accesor()->select_size() / sizeof(float));
-  table->pull_sparse(res_data.data(), keys, num);
- cntl->response_attachment().append((char *)res_data.data(),
- res_data.size() * sizeof(float));
+ auto dim = table->value_accesor()->select_dim();
+
+ thread_local std::string req_buffer;
+ req_buffer.reserve(req_buffer_size);
+
+ const void *data = cntl->request_attachment().fetch(
+      const_cast<char *>(req_buffer.data()), req_buffer_size);
+
+ auto value = PullSparseValue(num, dim);
+
+ value.DeserializeFromBytes(const_cast(data));
+
+  auto res_data = butil::get_object<std::vector<float>>();
+ res_data->resize(num * dim);
+ table->pull_sparse(res_data->data(), value);
+
+ cntl->response_attachment().append((char *)(res_data->data()),
+ res_data->size() * sizeof(float));
+ butil::return_object(res_data);
return 0;
}
@@ -538,7 +549,7 @@ int32_t BrpcPsService::stop_server(Table *table,
auto *p_server = _server;
std::thread t_stop([p_server]() {
p_server->stop();
- LOG(INFO) << "Server Stoped";
+ VLOG(3) << "Server Stoped";
});
t_stop.detach();
return 0;
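
The handler changes above swap per-request `std::vector<float>` construction for brpc's object pool; a minimal sketch of that borrow/return pattern:

```cpp
#include <vector>
#include "butil/object_pool.h"

// Borrow a pooled vector, fill it, return it. The pooled vector keeps its
// capacity across requests, so steady-state handling avoids reallocation.
void FillResponse(size_t n) {
  auto* buf = butil::get_object<std::vector<float>>();  // borrowed pointer
  buf->resize(n);
  // ... write results into *buf and append them to the response ...
  butil::return_object(buf);  // hand the buffer back to the pool
}
```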
diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc
index 096718768149c574fd57b91396879d7bec5d37e0..a356b77e73733ed9b657a7603adf57c5228bf3c5 100644
--- a/paddle/fluid/distributed/service/brpc_utils.cc
+++ b/paddle/fluid/distributed/service/brpc_utils.cc
@@ -324,7 +324,7 @@ std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port) {
while (hp->h_addr_list[i] != NULL) {
int_ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
- VLOG(0) << "Brpc Get host by name, host:" << ip << " -> ip: " << int_ip;
+ VLOG(3) << "Brpc Get host by name, host:" << ip << " -> ip: " << int_ip;
break;
}
diff --git a/paddle/fluid/distributed/service/communicator.cc b/paddle/fluid/distributed/service/communicator.cc
index 8699719e5cdcc8f40cf26fc90c17ad52849804d3..3d5ab8e16d90202d2365c14f764f5e0f53929b68 100644
--- a/paddle/fluid/distributed/service/communicator.cc
+++ b/paddle/fluid/distributed/service/communicator.cc
@@ -320,9 +320,11 @@ void Communicator::RpcRecvSparse(const std::string &varname, int table_id,
push_g_vec.push_back(tensor->data() + i * dim);
}
+ bool training = true;
+
auto status = _worker_ptr->pull_sparse(
(float **)push_g_vec.data(), table_id, // NOLINT
- sparse_push_keys.data(), sparse_push_keys.size());
+ sparse_push_keys.data(), sparse_push_keys.size(), training);
status.wait();
return;
}
diff --git a/paddle/fluid/distributed/service/communicator.h b/paddle/fluid/distributed/service/communicator.h
index 043fe9d83dfc53aaa5d13ef1f12745836129aaa0..fa60cab2b58779ede16cc51971277130bcaca909 100644
--- a/paddle/fluid/distributed/service/communicator.h
+++ b/paddle/fluid/distributed/service/communicator.h
@@ -310,6 +310,8 @@ class Communicator {
return _worker_ptr;
}
+ RecvCtxMap &GetRecvCtxMap() { return recv_varname_to_ctx_; }
+
  std::shared_ptr<PSClient> _worker_ptr;  // pointer to worker
protected:
diff --git a/paddle/fluid/distributed/service/env.h b/paddle/fluid/distributed/service/env.h
index 901aba0ad90c49c7403862997830bed7e0950dc0..ca395a776afd4e2ee53e0aeaebb94494d4f4e6a6 100644
--- a/paddle/fluid/distributed/service/env.h
+++ b/paddle/fluid/distributed/service/env.h
@@ -39,7 +39,7 @@ struct PSHost {
// |---ip---|---port---|--rank--|
// |-32bit--|--20bit---|--12bit-|
- // for pslib
+
uint64_t serialize_to_uint64() {
uint64_t host_label = 0;
host_label = inet_addr(ip.c_str());
@@ -175,14 +175,12 @@ class PSEnvironment {
host.ip = ip;
host.port = port;
host.rank = rank;
- if (sign_set.count(rank) > 0) {
- LOG(WARNING) << "ps-host :" << host.ip << ":" << host.port
- << ", rank:" << host.rank
- << " already register, ignore register";
- } else {
+
+ if (sign_set.count(rank) == 0) {
host_list.push_back(host);
sign_set.insert(rank);
}
+
return 0;
}
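
A worked sketch of the 32/20/12-bit packing described by the layout comment above (the helper name is illustrative; the real method also parses the ip string with inet_addr):

```cpp
#include <cstdint>

// |---ip---|---port---|--rank--|
// |-32bit--|--20bit---|--12bit-|
uint64_t PackHost(uint32_t ip, uint32_t port, uint32_t rank) {
  uint64_t label = ip;
  label = (label << 20) | (port & 0xFFFFFu);  // port fits in 20 bits
  label = (label << 12) | (rank & 0xFFFu);    // rank fits in 12 bits
  return label;
}
```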
diff --git a/paddle/fluid/distributed/service/graph_brpc_client.cc b/paddle/fluid/distributed/service/graph_brpc_client.cc
new file mode 100644
index 0000000000000000000000000000000000000000..eafb4d596cc1671db26189b84ea9d0c0c31ea398
--- /dev/null
+++ b/paddle/fluid/distributed/service/graph_brpc_client.cc
@@ -0,0 +1,332 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/service/graph_brpc_client.h"
+#include <algorithm>
+#include <memory>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+#include "Eigen/Dense"
+#include "paddle/fluid/distributed/service/brpc_ps_client.h"
+#include "paddle/fluid/distributed/table/table.h"
+#include "paddle/fluid/framework/archive.h"
+#include "paddle/fluid/string/string_helper.h"
+namespace paddle {
+namespace distributed {
+
+void GraphPsService_Stub::service(
+ ::google::protobuf::RpcController *controller,
+ const ::paddle::distributed::PsRequestMessage *request,
+ ::paddle::distributed::PsResponseMessage *response,
+ ::google::protobuf::Closure *done) {
+ if (graph_service != NULL && local_channel == channel()) {
+ // VLOG(0)<<"use local";
+ task_pool->enqueue([this, controller, request, response, done]() -> int {
+ this->graph_service->service(controller, request, response, done);
+ return 0;
+ });
+ } else {
+ // VLOG(0)<<"use server";
+ PsService_Stub::service(controller, request, response, done);
+ }
+}
+
+int GraphBrpcClient::get_server_index_by_id(uint64_t id) {
+ int shard_num = get_shard_num();
+ int shard_per_server = shard_num % server_size == 0
+ ? shard_num / server_size
+ : shard_num / server_size + 1;
+ return id % shard_num / shard_per_server;
+}
+
+std::future<int32_t> GraphBrpcClient::get_node_feat(
+    const uint32_t &table_id, const std::vector<uint64_t> &node_ids,
+    const std::vector<std::string> &feature_names,
+    std::vector<std::vector<std::string>> &res) {
+  std::vector<int> request2server;
+  std::vector<int> server2request(server_size, -1);
+ for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+ int server_index = get_server_index_by_id(node_ids[query_idx]);
+ if (server2request[server_index] == -1) {
+ server2request[server_index] = request2server.size();
+ request2server.push_back(server_index);
+ }
+ }
+ size_t request_call_num = request2server.size();
+  std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
+  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
+ for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+ int server_index = get_server_index_by_id(node_ids[query_idx]);
+ int request_idx = server2request[server_index];
+ node_id_buckets[request_idx].push_back(node_ids[query_idx]);
+ query_idx_buckets[request_idx].push_back(query_idx);
+ }
+
+ DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+ request_call_num,
+ [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
+ int ret = 0;
+ auto *closure = (DownpourBrpcClosure *)done;
+ int fail_num = 0;
+ for (int request_idx = 0; request_idx < request_call_num;
+ ++request_idx) {
+ if (closure->check_response(request_idx,
+ PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
+ ++fail_num;
+ } else {
+ auto &res_io_buffer =
+ closure->cntl(request_idx)->response_attachment();
+ butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
+ size_t bytes_size = io_buffer_itr.bytes_left();
+          std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
+ char *buffer = buffer_wrapper.get();
+ io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
+
+ for (size_t feat_idx = 0; feat_idx < feature_names.size();
+ ++feat_idx) {
+ for (size_t node_idx = 0;
+ node_idx < query_idx_buckets.at(request_idx).size();
+ ++node_idx) {
+ int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
+ size_t feat_len = *(size_t *)(buffer);
+ buffer += sizeof(size_t);
+ auto feature = std::string(buffer, feat_len);
+ res[feat_idx][query_idx] = feature;
+ buffer += feat_len;
+ }
+ }
+ }
+ if (fail_num == request_call_num) {
+ ret = -1;
+ }
+ }
+ closure->set_promise_value(ret);
+ });
+
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+
+ for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
+ int server_index = request2server[request_idx];
+ closure->request(request_idx)->set_cmd_id(PS_GRAPH_GET_NODE_FEAT);
+ closure->request(request_idx)->set_table_id(table_id);
+ closure->request(request_idx)->set_client_id(_client_id);
+ size_t node_num = node_id_buckets[request_idx].size();
+
+ closure->request(request_idx)
+ ->add_params((char *)node_id_buckets[request_idx].data(),
+ sizeof(uint64_t) * node_num);
+ std::string joint_feature_name =
+ paddle::string::join_strings(feature_names, '\t');
+ closure->request(request_idx)
+ ->add_params(joint_feature_name.c_str(), joint_feature_name.size());
+
+ GraphPsService_Stub rpc_stub =
+ getServiceStub(get_cmd_channel(server_index));
+ closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
+ rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
+ closure->response(request_idx), closure);
+ }
+
+ return fut;
+}
+// char* &buffer,int &actual_size
+std::future<int32_t> GraphBrpcClient::batch_sample_neighboors(
+    uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
+    std::vector<std::vector<std::pair<uint64_t, float>>> &res) {
+  std::vector<int> request2server;
+  std::vector<int> server2request(server_size, -1);
+ res.clear();
+ for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+ int server_index = get_server_index_by_id(node_ids[query_idx]);
+ if (server2request[server_index] == -1) {
+ server2request[server_index] = request2server.size();
+ request2server.push_back(server_index);
+ }
+    res.push_back(std::vector<std::pair<uint64_t, float>>());
+ }
+ size_t request_call_num = request2server.size();
+  std::vector<std::vector<uint64_t>> node_id_buckets(request_call_num);
+  std::vector<std::vector<int>> query_idx_buckets(request_call_num);
+ for (int query_idx = 0; query_idx < node_ids.size(); ++query_idx) {
+ int server_index = get_server_index_by_id(node_ids[query_idx]);
+ int request_idx = server2request[server_index];
+ node_id_buckets[request_idx].push_back(node_ids[query_idx]);
+ query_idx_buckets[request_idx].push_back(query_idx);
+ }
+
+ DownpourBrpcClosure *closure = new DownpourBrpcClosure(
+ request_call_num,
+ [&, node_id_buckets, query_idx_buckets, request_call_num](void *done) {
+ int ret = 0;
+ auto *closure = (DownpourBrpcClosure *)done;
+ int fail_num = 0;
+ for (int request_idx = 0; request_idx < request_call_num;
+ ++request_idx) {
+ if (closure->check_response(request_idx,
+ PS_GRAPH_SAMPLE_NEIGHBOORS) != 0) {
+ ++fail_num;
+ } else {
+ auto &res_io_buffer =
+ closure->cntl(request_idx)->response_attachment();
+ butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
+ size_t bytes_size = io_buffer_itr.bytes_left();
+          std::unique_ptr<char[]> buffer_wrapper(new char[bytes_size]);
+ char *buffer = buffer_wrapper.get();
+ io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
+
+ size_t node_num = *(size_t *)buffer;
+ int *actual_sizes = (int *)(buffer + sizeof(size_t));
+ char *node_buffer =
+ buffer + sizeof(size_t) + sizeof(int) * node_num;
+
+ int offset = 0;
+ for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
+ int query_idx = query_idx_buckets.at(request_idx).at(node_idx);
+ int actual_size = actual_sizes[node_idx];
+ int start = 0;
+ while (start < actual_size) {
+ res[query_idx].push_back(
+ {*(uint64_t *)(node_buffer + offset + start),
+ *(float *)(node_buffer + offset + start +
+ GraphNode::id_size)});
+ start += GraphNode::id_size + GraphNode::weight_size;
+ }
+ offset += actual_size;
+ }
+ }
+ if (fail_num == request_call_num) {
+ ret = -1;
+ }
+ }
+ closure->set_promise_value(ret);
+ });
+
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+
+ for (int request_idx = 0; request_idx < request_call_num; ++request_idx) {
+ int server_index = request2server[request_idx];
+ closure->request(request_idx)->set_cmd_id(PS_GRAPH_SAMPLE_NEIGHBOORS);
+ closure->request(request_idx)->set_table_id(table_id);
+ closure->request(request_idx)->set_client_id(_client_id);
+ size_t node_num = node_id_buckets[request_idx].size();
+
+ closure->request(request_idx)
+ ->add_params((char *)node_id_buckets[request_idx].data(),
+ sizeof(uint64_t) * node_num);
+ closure->request(request_idx)
+ ->add_params((char *)&sample_size, sizeof(int));
+ // PsService_Stub rpc_stub(get_cmd_channel(server_index));
+ GraphPsService_Stub rpc_stub =
+ getServiceStub(get_cmd_channel(server_index));
+ closure->cntl(request_idx)->set_log_id(butil::gettimeofday_ms());
+ rpc_stub.service(closure->cntl(request_idx), closure->request(request_idx),
+ closure->response(request_idx), closure);
+ }
+
+ return fut;
+}
+std::future<int32_t> GraphBrpcClient::random_sample_nodes(
+    uint32_t table_id, int server_index, int sample_size,
+    std::vector<uint64_t> &ids) {
+ DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
+ int ret = 0;
+ auto *closure = (DownpourBrpcClosure *)done;
+ if (closure->check_response(0, PS_GRAPH_SAMPLE_NODES) != 0) {
+ ret = -1;
+ } else {
+ auto &res_io_buffer = closure->cntl(0)->response_attachment();
+ butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
+ size_t bytes_size = io_buffer_itr.bytes_left();
+ char buffer[bytes_size];
+ auto size = io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
+ int index = 0;
+ while (index < bytes_size) {
+ ids.push_back(*(uint64_t *)(buffer + index));
+ index += GraphNode::id_size;
+ }
+ }
+ closure->set_promise_value(ret);
+ });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+ closure->request(0)->set_cmd_id(PS_GRAPH_SAMPLE_NODES);
+ closure->request(0)->set_table_id(table_id);
+ closure->request(0)->set_client_id(_client_id);
+ closure->request(0)->add_params((char *)&sample_size, sizeof(int));
+ // PsService_Stub rpc_stub(get_cmd_channel(server_index));
+ GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
+ closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
+ rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
+ closure);
+ return fut;
+}
+std::future<int32_t> GraphBrpcClient::pull_graph_list(
+    uint32_t table_id, int server_index, int start, int size, int step,
+    std::vector<FeatureNode> &res) {
+ DownpourBrpcClosure *closure = new DownpourBrpcClosure(1, [&](void *done) {
+ int ret = 0;
+ auto *closure = (DownpourBrpcClosure *)done;
+ if (closure->check_response(0, PS_PULL_GRAPH_LIST) != 0) {
+ ret = -1;
+ } else {
+ auto &res_io_buffer = closure->cntl(0)->response_attachment();
+ butil::IOBufBytesIterator io_buffer_itr(res_io_buffer);
+ size_t bytes_size = io_buffer_itr.bytes_left();
+ char buffer[bytes_size];
+ io_buffer_itr.copy_and_forward((void *)(buffer), bytes_size);
+ int index = 0;
+ while (index < bytes_size) {
+ FeatureNode node;
+ node.recover_from_buffer(buffer + index);
+ index += node.get_size(false);
+ res.push_back(node);
+ }
+ }
+ closure->set_promise_value(ret);
+ });
+  auto promise = std::make_shared<std::promise<int32_t>>();
+  closure->add_promise(promise);
+  std::future<int32_t> fut = promise->get_future();
+ closure->request(0)->set_cmd_id(PS_PULL_GRAPH_LIST);
+ closure->request(0)->set_table_id(table_id);
+ closure->request(0)->set_client_id(_client_id);
+ closure->request(0)->add_params((char *)&start, sizeof(int));
+ closure->request(0)->add_params((char *)&size, sizeof(int));
+ closure->request(0)->add_params((char *)&step, sizeof(int));
+ // PsService_Stub rpc_stub(get_cmd_channel(server_index));
+ GraphPsService_Stub rpc_stub = getServiceStub(get_cmd_channel(server_index));
+ closure->cntl(0)->set_log_id(butil::gettimeofday_ms());
+ rpc_stub.service(closure->cntl(0), closure->request(0), closure->response(0),
+ closure);
+ return fut;
+}
+int32_t GraphBrpcClient::initialize() {
+ // set_shard_num(_config.shard_num());
+ BrpcPsClient::initialize();
+ server_size = get_server_nums();
+ graph_service = NULL;
+ local_channel = NULL;
+ return 0;
+}
+}  // namespace distributed
+}  // namespace paddle
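
`get_server_index_by_id` deals shards out contiguously, `ceil(shard_num / server_size)` per server (the last server may own fewer); a worked check of the mapping, as a standalone mirror of the member function above:

```cpp
#include <cassert>
#include <cstdint>

int ServerIndex(uint64_t id, int shard_num, int server_size) {
  int shard_per_server = shard_num % server_size == 0
                             ? shard_num / server_size
                             : shard_num / server_size + 1;
  return static_cast<int>(id % shard_num) / shard_per_server;
}

int main() {
  // 10 shards over 3 servers -> 4 shards per server:
  // shards 0..3 -> server 0, 4..7 -> server 1, 8..9 -> server 2.
  assert(ServerIndex(14, 10, 3) == 1);  // 14 % 10 = shard 4
  assert(ServerIndex(9, 10, 3) == 2);   // shard 9
  return 0;
}
```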
diff --git a/paddle/fluid/distributed/service/graph_brpc_client.h b/paddle/fluid/distributed/service/graph_brpc_client.h
new file mode 100644
index 0000000000000000000000000000000000000000..4e6775a4bedaf1a4028fe483f58be818ef1e3581
--- /dev/null
+++ b/paddle/fluid/distributed/service/graph_brpc_client.h
@@ -0,0 +1,105 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <future>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "ThreadPool.h"
+#include "brpc/channel.h"
+#include "brpc/controller.h"
+#include "brpc/server.h"
+#include "paddle/fluid/distributed/service/brpc_ps_client.h"
+#include "paddle/fluid/distributed/service/graph_brpc_server.h"
+#include "paddle/fluid/distributed/service/ps_client.h"
+#include "paddle/fluid/distributed/table/table.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor_util.h"
+
+namespace paddle {
+namespace distributed {
+
+class GraphPsService_Stub : public PsService_Stub {
+ public:
+ GraphPsService_Stub(::google::protobuf::RpcChannel* channel,
+ ::google::protobuf::RpcChannel* local_channel = NULL,
+ GraphBrpcService* service = NULL, int thread_num = 1)
+ : PsService_Stub(channel) {
+ this->local_channel = local_channel;
+ this->graph_service = service;
+ task_pool.reset(new ::ThreadPool(thread_num));
+ }
+ virtual ~GraphPsService_Stub() {}
+
+ // implements PsService ------------------------------------------
+ GraphBrpcService* graph_service;
+ std::shared_ptr<::ThreadPool> task_pool;
+ ::google::protobuf::RpcChannel* local_channel;
+ GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(GraphPsService_Stub);
+ void service(::google::protobuf::RpcController* controller,
+ const ::paddle::distributed::PsRequestMessage* request,
+ ::paddle::distributed::PsResponseMessage* response,
+ ::google::protobuf::Closure* done);
+};
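+// A plausible dispatch for GraphPsService_Stub::service, inferred from the
+// members above (the real body is defined in the .cc of this PR): requests
+// whose target channel is the process-local one can skip brpc entirely and
+// run on the stub's task_pool against graph_service; everything else takes
+// the normal remote stub path, e.g. roughly:
+//
+//   if (graph_service != nullptr && channel() == local_channel) {
+//     task_pool->enqueue(
+//         [=] { graph_service->service(controller, request, response, done); });
+//   } else {
+//     PsService_Stub::service(controller, request, response, done);
+//   }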
+class GraphBrpcClient : public BrpcPsClient {
+ public:
+ GraphBrpcClient() {}
+ virtual ~GraphBrpcClient() {}
+  // given a batch of nodes, sample the graph neighbors of each of them
+  virtual std::future<int32_t> batch_sample_neighboors(
+      uint32_t table_id, std::vector<uint64_t> node_ids, int sample_size,
+      std::vector<std::vector<std::pair<uint64_t, float>>>& res);
+
+  virtual std::future<int32_t> pull_graph_list(uint32_t table_id,
+                                               int server_index, int start,
+                                               int size, int step,
+                                               std::vector<FeatureNode>& res);
+  virtual std::future<int32_t> random_sample_nodes(uint32_t table_id,
+                                                   int server_index,
+                                                   int sample_size,
+                                                   std::vector<uint64_t>& ids);
+  virtual std::future<int32_t> get_node_feat(
+      const uint32_t& table_id, const std::vector<uint64_t>& node_ids,
+      const std::vector<std::string>& feature_names,
+      std::vector<std::vector<std::string>>& res);
+ virtual int32_t initialize();
+ int get_shard_num() { return shard_num; }
+ void set_shard_num(int shard_num) { this->shard_num = shard_num; }
+ int get_server_index_by_id(uint64_t id);
+ void set_local_channel(int index) {
+ this->local_channel = get_cmd_channel(index);
+ }
+ void set_local_graph_service(GraphBrpcService* graph_service) {
+ this->graph_service = graph_service;
+ }
+ GraphPsService_Stub getServiceStub(::google::protobuf::RpcChannel* channel,
+ int thread_num = 1) {
+ return GraphPsService_Stub(channel, local_channel, graph_service,
+ thread_num);
+ }
+
+ private:
+ int shard_num;
+ size_t server_size;
+ ::google::protobuf::RpcChannel* local_channel;
+ GraphBrpcService* graph_service;
+};
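+// Illustrative client usage (table id and node ids are example values only):
+//
+//   auto* client = (GraphBrpcClient*)
+//       paddle::distributed::PSClientFactory::create(worker_proto);
+//   client->configure(worker_proto, dense_regions, ps_env, client_id);
+//   client->initialize();
+//   std::vector<std::vector<std::pair<uint64_t, float>>> neighbors;
+//   auto status = client->batch_sample_neighboors(/*table_id=*/0, {37, 59},
+//                                                 /*sample_size=*/10,
+//                                                 neighbors);
+//   status.wait();  // each future resolves to 0 on success, -1 on failure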
+
+} // namespace distributed
+} // namespace paddle
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.cc b/paddle/fluid/distributed/service/graph_brpc_server.cc
new file mode 100644
index 0000000000000000000000000000000000000000..bdd926278b624b9e9bfdf19a4f293784bef6e28f
--- /dev/null
+++ b/paddle/fluid/distributed/service/graph_brpc_server.cc
@@ -0,0 +1,348 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/service/graph_brpc_server.h"
+#include "paddle/fluid/distributed/service/brpc_ps_server.h"
+
+#include <thread>  // NOLINT
+#include "butil/endpoint.h"
+#include <iomanip>
+#include "paddle/fluid/distributed/service/brpc_ps_client.h"
+#include "paddle/fluid/framework/archive.h"
+#include "paddle/fluid/platform/profiler.h"
+namespace paddle {
+namespace distributed {
+
+int32_t GraphBrpcServer::initialize() {
+ auto &service_config = _config.downpour_server_param().service_param();
+ if (!service_config.has_service_class()) {
+ LOG(ERROR) << "miss service_class in ServerServiceParameter";
+ return -1;
+ }
+ auto *service =
+ CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
+ if (service == NULL) {
+ LOG(ERROR) << "service is unregistered, service_name:"
+ << service_config.service_class();
+ return -1;
+ }
+
+ _service.reset(service);
+ if (service->configure(this) != 0 || service->initialize() != 0) {
+ LOG(ERROR) << "service initialize failed, service_name:"
+ << service_config.service_class();
+ return -1;
+ }
+ if (_server.AddService(service, brpc::SERVER_DOESNT_OWN_SERVICE) != 0) {
+ LOG(ERROR) << "service add to brpc failed, service:"
+ << service_config.service_class();
+ return -1;
+ }
+ return 0;
+}
+
+uint64_t GraphBrpcServer::start(const std::string &ip, uint32_t port) {
+  std::unique_lock<std::mutex> lock(mutex_);
+
+ std::string ip_port = ip + ":" + std::to_string(port);
+ VLOG(3) << "server of rank " << _rank << " starts at " << ip_port;
+ brpc::ServerOptions options;
+
+ int num_threads = std::thread::hardware_concurrency();
+ auto trainers = _environment->get_trainers();
+ options.num_threads = trainers > num_threads ? trainers : num_threads;
+
+ if (_server.Start(ip_port.c_str(), &options) != 0) {
+ LOG(ERROR) << "GraphBrpcServer start failed, ip_port=" << ip_port;
+ return 0;
+ }
+ _environment->registe_ps_server(ip, port, _rank);
+ return 0;
+}
+
+int32_t GraphBrpcServer::port() { return _server.listen_address().port; }
+
+int32_t GraphBrpcService::initialize() {
+ _is_initialize_shard_info = false;
+ _service_handler_map[PS_STOP_SERVER] = &GraphBrpcService::stop_server;
+ _service_handler_map[PS_LOAD_ONE_TABLE] = &GraphBrpcService::load_one_table;
+ _service_handler_map[PS_LOAD_ALL_TABLE] = &GraphBrpcService::load_all_table;
+
+ _service_handler_map[PS_PRINT_TABLE_STAT] =
+ &GraphBrpcService::print_table_stat;
+ _service_handler_map[PS_BARRIER] = &GraphBrpcService::barrier;
+ _service_handler_map[PS_START_PROFILER] = &GraphBrpcService::start_profiler;
+ _service_handler_map[PS_STOP_PROFILER] = &GraphBrpcService::stop_profiler;
+
+ _service_handler_map[PS_PULL_GRAPH_LIST] = &GraphBrpcService::pull_graph_list;
+ _service_handler_map[PS_GRAPH_SAMPLE_NEIGHBOORS] =
+ &GraphBrpcService::graph_random_sample_neighboors;
+ _service_handler_map[PS_GRAPH_SAMPLE_NODES] =
+ &GraphBrpcService::graph_random_sample_nodes;
+ _service_handler_map[PS_GRAPH_GET_NODE_FEAT] =
+ &GraphBrpcService::graph_get_node_feat;
+
+  // Initialize shard info; the shard information of server_list can only be
+  // fetched from env after the server has started.
+ initialize_shard_info();
+
+ return 0;
+}
+
+#define CHECK_TABLE_EXIST(table, request, response) \
+ if (table == NULL) { \
+ std::string err_msg("table not found with table_id:"); \
+ err_msg.append(std::to_string(request.table_id())); \
+ set_response_code(response, -1, err_msg.c_str()); \
+ return -1; \
+ }
+
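+// initialize_shard_info below uses double-checked locking: the unsynchronized
+// read of _is_initialize_shard_info skips the lock on the common path, and the
+// second check under _initialize_shard_mutex keeps concurrent callers from
+// setting the shard count on every table twice.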
+int32_t GraphBrpcService::initialize_shard_info() {
+ if (!_is_initialize_shard_info) {
+    std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
+ if (_is_initialize_shard_info) {
+ return 0;
+ }
+ size_t shard_num = _server->environment()->get_ps_servers().size();
+ auto &table_map = *(_server->table());
+    for (auto &itr : table_map) {
+ itr.second->set_shard(_rank, shard_num);
+ }
+ _is_initialize_shard_info = true;
+ }
+ return 0;
+}
+
+void GraphBrpcService::service(google::protobuf::RpcController *cntl_base,
+ const PsRequestMessage *request,
+ PsResponseMessage *response,
+ google::protobuf::Closure *done) {
+ brpc::ClosureGuard done_guard(done);
+ std::string log_label("ReceiveCmd-");
+ if (!request->has_table_id()) {
+    set_response_code(*response, -1, "PsRequestMessage.table_id is required");
+ return;
+ }
+
+ response->set_err_code(0);
+ response->set_err_msg("");
+ auto *table = _server->table(request->table_id());
+  brpc::Controller *cntl = static_cast<brpc::Controller *>(cntl_base);
+ auto itr = _service_handler_map.find(request->cmd_id());
+ if (itr == _service_handler_map.end()) {
+ std::string err_msg(
+ "undefined cmd_id, should match PsCmdID in ps.proto, cmd_id:");
+ err_msg.append(std::to_string(request->cmd_id()));
+ set_response_code(*response, -1, err_msg.c_str());
+ return;
+ }
+ serviceFunc handler_func = itr->second;
+ int service_ret = (this->*handler_func)(table, *request, *response, cntl);
+ if (service_ret != 0) {
+ response->set_err_code(service_ret);
+ response->set_err_msg("server internal error");
+ }
+}
+
+int32_t GraphBrpcService::barrier(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ CHECK_TABLE_EXIST(table, request, response)
+
+ if (request.params_size() < 1) {
+    set_response_code(response, -1,
+                      "PsRequestMessage.params requires at least "
+                      "1 argument: num of sparse_key");
+ return 0;
+ }
+
+ auto trainer_id = request.client_id();
+ auto barrier_type = request.params(0);
+ table->barrier(trainer_id, barrier_type);
+ return 0;
+}
+
+int32_t GraphBrpcService::print_table_stat(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ CHECK_TABLE_EXIST(table, request, response)
+  std::pair<int64_t, int64_t> ret = table->print_table_stat();
+ paddle::framework::BinaryArchive ar;
+ ar << ret.first << ret.second;
+ std::string table_info(ar.Buffer(), ar.Length());
+ response.set_data(table_info);
+
+ return 0;
+}
+
+int32_t GraphBrpcService::load_one_table(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ CHECK_TABLE_EXIST(table, request, response)
+ if (request.params_size() < 2) {
+    set_response_code(
+        response, -1,
+        "PsRequestMessage.params requires at least 2 arguments: path & "
+        "load_param");
+ return -1;
+ }
+ if (table->load(request.params(0), request.params(1)) != 0) {
+ set_response_code(response, -1, "table load failed");
+ return -1;
+ }
+ return 0;
+}
+
+int32_t GraphBrpcService::load_all_table(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ auto &table_map = *(_server->table());
+ for (auto &itr : table_map) {
+ if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
+ LOG(ERROR) << "load table[" << itr.first << "] failed";
+ return -1;
+ }
+ }
+ return 0;
+}
+
+int32_t GraphBrpcService::stop_server(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ GraphBrpcServer *p_server = (GraphBrpcServer *)_server;
+ std::thread t_stop([p_server]() {
+ p_server->stop();
+ LOG(INFO) << "Server Stoped";
+ });
+ p_server->export_cv()->notify_all();
+ t_stop.detach();
+ return 0;
+}
+
+int32_t GraphBrpcService::stop_profiler(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ platform::DisableProfiler(platform::EventSortingKey::kDefault,
+ string::Sprintf("server_%s_profile", _rank));
+ return 0;
+}
+
+int32_t GraphBrpcService::start_profiler(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ platform::EnableProfiler(platform::ProfilerState::kCPU);
+ return 0;
+}
+
+int32_t GraphBrpcService::pull_graph_list(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ CHECK_TABLE_EXIST(table, request, response)
+ if (request.params_size() < 3) {
+ set_response_code(response, -1,
+ "pull_graph_list request requires at least 3 arguments");
+ return 0;
+ }
+ int start = *(int *)(request.params(0).c_str());
+ int size = *(int *)(request.params(1).c_str());
+ int step = *(int *)(request.params(2).c_str());
+  std::unique_ptr<char[]> buffer;
+ int actual_size;
+ ((GraphTable *)table)
+ ->pull_graph_list(start, size, buffer, actual_size, false, step);
+ cntl->response_attachment().append(buffer.get(), actual_size);
+ return 0;
+}
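+// The three params consumed above are the raw int bytes appended by
+// GraphBrpcClient::pull_graph_list (add_params((char *)&start, sizeof(int)),
+// etc.), so the *(int *) casts on the server mirror the client encoding.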
+int32_t GraphBrpcService::graph_random_sample_neighboors(
+ Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ CHECK_TABLE_EXIST(table, request, response)
+ if (request.params_size() < 2) {
+ set_response_code(
+ response, -1,
+ "graph_random_sample request requires at least 2 arguments");
+ return 0;
+ }
+ size_t node_num = request.params(0).size() / sizeof(uint64_t);
+ uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
+  int sample_size = *(int *)(request.params(1).c_str());
+ std::vector> buffers(node_num);
+ std::vector actual_sizes(node_num, 0);
+ ((GraphTable *)table)
+ ->random_sample_neighboors(node_data, sample_size, buffers, actual_sizes);
+
+ cntl->response_attachment().append(&node_num, sizeof(size_t));
+ cntl->response_attachment().append(actual_sizes.data(),
+ sizeof(int) * node_num);
+ for (size_t idx = 0; idx < node_num; ++idx) {
+ cntl->response_attachment().append(buffers[idx].get(), actual_sizes[idx]);
+ }
+ return 0;
+}
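+// Response attachment layout produced above, which the client-side closure of
+// GraphBrpcClient::batch_sample_neighboors unpacks in the same order:
+//
+//   [size_t node_num][int actual_sizes[node_num]][buf_0][buf_1]...[buf_n-1]
+//
+// where buf_i holds actual_sizes[i] bytes of sampled neighbors for the i-th
+// requested node (id/weight pairs, per the client-side result type).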
+int32_t GraphBrpcService::graph_random_sample_nodes(
+ Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+ brpc::Controller *cntl) {
+  size_t size = *(int *)(request.params(0).c_str());
+  std::unique_ptr<char[]> buffer;
+ int actual_size;
+ if (((GraphTable *)table)->random_sample_nodes(size, buffer, actual_size) ==
+ 0) {
+ cntl->response_attachment().append(buffer.get(), actual_size);
+  } else {
+    cntl->response_attachment().append(NULL, 0);
+  }
+
+ return 0;
+}
+
+int32_t GraphBrpcService::graph_get_node_feat(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl) {
+ CHECK_TABLE_EXIST(table, request, response)
+ if (request.params_size() < 2) {
+ set_response_code(
+ response, -1,
+ "graph_get_node_feat request requires at least 2 arguments");
+ return 0;
+ }
+ size_t node_num = request.params(0).size() / sizeof(uint64_t);
+ uint64_t *node_data = (uint64_t *)(request.params(0).c_str());
+  std::vector<uint64_t> node_ids(node_data, node_data + node_num);
+
+  std::vector<std::string> feature_names =
+      paddle::string::split_string<std::string>(request.params(1), "\t");
+
+  std::vector<std::vector<std::string>> feature(
+      feature_names.size(), std::vector<std::string>(node_num));
+
+ ((GraphTable *)table)->get_node_feat(node_ids, feature_names, feature);
+
+ for (size_t feat_idx = 0; feat_idx < feature_names.size(); ++feat_idx) {
+ for (size_t node_idx = 0; node_idx < node_num; ++node_idx) {
+ size_t feat_len = feature[feat_idx][node_idx].size();
+ cntl->response_attachment().append(&feat_len, sizeof(size_t));
+ cntl->response_attachment().append(feature[feat_idx][node_idx].data(),
+ feat_len);
+ }
+ }
+
+ return 0;
+}
+} // namespace distributed
+} // namespace paddle
diff --git a/paddle/fluid/distributed/service/graph_brpc_server.h b/paddle/fluid/distributed/service/graph_brpc_server.h
new file mode 100644
index 0000000000000000000000000000000000000000..32c572f9e6c2bf759c59190679bcf7570a807f2d
--- /dev/null
+++ b/paddle/fluid/distributed/service/graph_brpc_server.h
@@ -0,0 +1,114 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "brpc/channel.h"
+#include "brpc/controller.h"
+#include "brpc/server.h"
+
+#include <memory>
+#include <vector>
+#include "paddle/fluid/distributed/service/brpc_ps_server.h"
+#include "paddle/fluid/distributed/service/server.h"
+#include "paddle/fluid/distributed/table/common_graph_table.h"
+#include "paddle/fluid/distributed/table/table.h"
+namespace paddle {
+namespace distributed {
+class GraphBrpcServer : public PSServer {
+ public:
+ GraphBrpcServer() {}
+ virtual ~GraphBrpcServer() {}
+ PsBaseService *get_service() { return _service.get(); }
+ virtual uint64_t start(const std::string &ip, uint32_t port);
+ virtual int32_t stop() {
+    std::unique_lock<std::mutex> lock(mutex_);
+ if (stoped_) return 0;
+ stoped_ = true;
+ // cv_.notify_all();
+ _server.Stop(1000);
+ _server.Join();
+ return 0;
+ }
+ virtual int32_t port();
+
+ std::condition_variable *export_cv() { return &cv_; }
+
+ private:
+ virtual int32_t initialize();
+ mutable std::mutex mutex_;
+ std::condition_variable cv_;
+ bool stoped_ = false;
+ brpc::Server _server;
+  std::shared_ptr<PsBaseService> _service;
+  std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
+};
+
+class GraphBrpcService;
+
+typedef int32_t (GraphBrpcService::*serviceFunc)(
+ Table *table, const PsRequestMessage &request, PsResponseMessage &response,
+ brpc::Controller *cntl);
+
+class GraphBrpcService : public PsBaseService {
+ public:
+ virtual int32_t initialize() override;
+
+ virtual void service(::google::protobuf::RpcController *controller,
+ const PsRequestMessage *request,
+ PsResponseMessage *response,
+ ::google::protobuf::Closure *done) override;
+
+ protected:
+  std::unordered_map<int32_t, serviceFunc> _service_handler_map;
+ int32_t initialize_shard_info();
+ int32_t pull_graph_list(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+ int32_t graph_random_sample_neighboors(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl);
+ int32_t graph_random_sample_nodes(Table *table,
+ const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl);
+ int32_t graph_get_node_feat(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response,
+ brpc::Controller *cntl);
+ int32_t barrier(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+ int32_t load_one_table(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+ int32_t load_all_table(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+ int32_t stop_server(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+ int32_t start_profiler(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+ int32_t stop_profiler(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+
+ int32_t print_table_stat(Table *table, const PsRequestMessage &request,
+ PsResponseMessage &response, brpc::Controller *cntl);
+
+ private:
+ bool _is_initialize_shard_info;
+ std::mutex _initialize_shard_mutex;
+  std::unordered_map<int32_t, serviceFunc> _msg_handler_map;
+  std::vector<float> _ori_values;
+ const int sample_nodes_ranges = 23;
+};
+
+} // namespace distributed
+} // namespace paddle
diff --git a/paddle/fluid/distributed/service/graph_py_service.cc b/paddle/fluid/distributed/service/graph_py_service.cc
new file mode 100644
index 0000000000000000000000000000000000000000..61e4e0cf7bb9155d25c630296c2b55a7d3400bfc
--- /dev/null
+++ b/paddle/fluid/distributed/service/graph_py_service.cc
@@ -0,0 +1,325 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/distributed/service/graph_py_service.h"
+#include <thread>  // NOLINT
+#include "butil/endpoint.h"
+#include <iomanip>
+#include "paddle/fluid/distributed/table/table.h"
+#include "paddle/fluid/framework/archive.h"
+#include "paddle/fluid/platform/profiler.h"
+namespace paddle {
+namespace distributed {
+std::vector<std::string> GraphPyService::split(std::string& str,
+                                               const char pattern) {
+ std::vector res;
+ std::stringstream input(str);
+ std::string temp;
+ while (std::getline(input, temp, pattern)) {
+ res.push_back(temp);
+ }
+ return res;
+}
+
+void GraphPyService::add_table_feat_conf(std::string table_name,
+ std::string feat_name,
+ std::string feat_dtype,
+ int32_t feat_shape) {
+ if (this->table_id_map.count(table_name)) {
+ this->table_feat_conf_table_name.push_back(table_name);
+ this->table_feat_conf_feat_name.push_back(feat_name);
+ this->table_feat_conf_feat_dtype.push_back(feat_dtype);
+ this->table_feat_conf_feat_shape.push_back(feat_shape);
+ }
+}
+
+void GraphPyService::set_up(std::string ips_str, int shard_num,
+                            std::vector<std::string> node_types,
+                            std::vector<std::string> edge_types) {
+ set_shard_num(shard_num);
+ set_num_node_types(node_types.size());
+
+ for (size_t table_id = 0; table_id < node_types.size(); table_id++) {
+ this->table_id_map[node_types[table_id]] = this->table_id_map.size();
+ }
+ for (size_t table_id = 0; table_id < edge_types.size(); table_id++) {
+ this->table_id_map[edge_types[table_id]] = this->table_id_map.size();
+ }
+  server_size = 0;
+  std::vector<std::string> ips_list = split(ips_str, ';');
+ int index = 0;
+ for (auto ips : ips_list) {
+ auto ip_and_port = split(ips, ':');
+ server_list.push_back(ip_and_port[0]);
+ port_list.push_back(ip_and_port[1]);
+ uint32_t port = stoul(ip_and_port[1]);
+ auto ph_host = paddle::distributed::PSHost(ip_and_port[0], port, index);
+ host_sign_list.push_back(ph_host.serialize_to_string());
+ index++;
+ }
+}
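+// Illustrative set_up call (addresses and counts are example values only):
+// two servers, 127 shards, one node type and one edge type; per the insertion
+// order above, node table "user" gets table id 0 and edge table "u2u" id 1.
+//
+//   service.set_up("127.0.0.1:8245;127.0.0.1:8246", 127, {"user"}, {"u2u"});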
+void GraphPyClient::start_client() {
+  std::map<uint64_t, std::vector<paddle::distributed::Region>> dense_regions;
+  dense_regions.insert(
+      std::pair<uint64_t, std::vector<paddle::distributed::Region>>(0, {}));
+ auto regions = dense_regions[0];
+ ::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
+ paddle::distributed::PaddlePSEnvironment _ps_env;
+ auto servers_ = host_sign_list.size();
+ _ps_env = paddle::distributed::PaddlePSEnvironment();
+ _ps_env.set_ps_servers(&host_sign_list, servers_);
+  worker_ptr = std::shared_ptr<paddle::distributed::GraphBrpcClient>(
+      (paddle::distributed::GraphBrpcClient*)
+          paddle::distributed::PSClientFactory::create(worker_proto));
+ worker_ptr->configure(worker_proto, dense_regions, _ps_env, client_id);
+ worker_ptr->set_shard_num(get_shard_num());
+}
+void GraphPyServer::start_server(bool block) {
+ std::string ip = server_list[rank];
+ uint32_t port = std::stoul(port_list[rank]);
+ ::paddle::distributed::PSParameter server_proto = this->GetServerProto();
+
+ auto _ps_env = paddle::distributed::PaddlePSEnvironment();
+ _ps_env.set_ps_servers(&this->host_sign_list,
+ this->host_sign_list.size()); // test
+  pserver_ptr = std::shared_ptr<paddle::distributed::GraphBrpcServer>(
+      (paddle::distributed::GraphBrpcServer*)
+          paddle::distributed::PSServerFactory::create(server_proto));
+ VLOG(0) << "pserver-ptr created ";
+  std::vector<framework::ProgramDesc> empty_vec;
+ framework::ProgramDesc empty_prog;
+ empty_vec.push_back(empty_prog);
+ pserver_ptr->configure(server_proto, _ps_env, rank, empty_vec);
+ pserver_ptr->start(ip, port);
+ std::condition_variable* cv_ = pserver_ptr->export_cv();
+ if (block) {
+ std::mutex mutex_;
+ std::unique_lock lock(mutex_);
+ cv_->wait(lock);
+ }
+}
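+// With block=true the caller parks on the server's exported condition
+// variable; GraphBrpcService::stop_server notify_all()s it on PS_STOP_SERVER,
+// so start_server(true) only returns once a client asks the server to stop.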
+::paddle::distributed::PSParameter GraphPyServer::GetServerProto() {
+ // Generate server proto desc
+ ::paddle::distributed::PSParameter server_fleet_desc;
+ ::paddle::distributed::ServerParameter* server_proto =
+ server_fleet_desc.mutable_server_param();
+ ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
+ server_proto->mutable_downpour_server_param();
+ ::paddle::distributed::ServerServiceParameter* server_service_proto =
+ downpour_server_proto->mutable_service_param();
+ server_service_proto->set_service_class("GraphBrpcService");
+ server_service_proto->set_server_class("GraphBrpcServer");
+ server_service_proto->set_client_class("GraphBrpcClient");
+ server_service_proto->set_start_server_port(0);
+ server_service_proto->set_server_thread_num(12);
+
+ for (auto& tuple : this->table_id_map) {
+ VLOG(0) << " make a new table " << tuple.second;
+ ::paddle::distributed::TableParameter* sparse_table_proto =
+ downpour_server_proto->add_downpour_table_param();
+    std::vector<std::string> feat_name;
+    std::vector<std::string> feat_dtype;
+    std::vector<int32_t> feat_shape;
+ for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
+ if (tuple.first == table_feat_conf_table_name[i]) {
+ feat_name.push_back(table_feat_conf_feat_name[i]);
+ feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
+ feat_shape.push_back(table_feat_conf_feat_shape[i]);
+ }
+ }
+ std::string table_type;
+ if (tuple.second < this->num_node_types) {
+ table_type = "node";
+ } else {
+ table_type = "edge";
+ }
+
+ GetDownpourSparseTableProto(sparse_table_proto, tuple.second, tuple.first,
+ table_type, feat_name, feat_dtype, feat_shape);
+ }
+
+ return server_fleet_desc;
+}
+
+::paddle::distributed::PSParameter GraphPyClient::GetWorkerProto() {
+ ::paddle::distributed::PSParameter worker_fleet_desc;
+ ::paddle::distributed::WorkerParameter* worker_proto =
+ worker_fleet_desc.mutable_worker_param();
+
+ ::paddle::distributed::DownpourWorkerParameter* downpour_worker_proto =
+ worker_proto->mutable_downpour_worker_param();
+
+ for (auto& tuple : this->table_id_map) {
+ VLOG(0) << " make a new table " << tuple.second;
+ ::paddle::distributed::TableParameter* worker_sparse_table_proto =
+ downpour_worker_proto->add_downpour_table_param();
+    std::vector<std::string> feat_name;
+    std::vector<std::string> feat_dtype;
+    std::vector<int32_t> feat_shape;
+ for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
+ if (tuple.first == table_feat_conf_table_name[i]) {
+ feat_name.push_back(table_feat_conf_feat_name[i]);
+ feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
+ feat_shape.push_back(table_feat_conf_feat_shape[i]);
+ }
+ }
+ std::string table_type;
+ if (tuple.second < this->num_node_types) {
+ table_type = "node";
+ } else {
+ table_type = "edge";
+ }
+
+ GetDownpourSparseTableProto(worker_sparse_table_proto, tuple.second,
+ tuple.first, table_type, feat_name, feat_dtype,
+ feat_shape);
+ }
+
+ ::paddle::distributed::ServerParameter* server_proto =
+ worker_fleet_desc.mutable_server_param();
+ ::paddle::distributed::DownpourServerParameter* downpour_server_proto =
+ server_proto->mutable_downpour_server_param();
+ ::paddle::distributed::ServerServiceParameter* server_service_proto =
+ downpour_server_proto->mutable_service_param();
+ server_service_proto->set_service_class("GraphBrpcService");
+ server_service_proto->set_server_class("GraphBrpcServer");
+ server_service_proto->set_client_class("GraphBrpcClient");
+ server_service_proto->set_start_server_port(0);
+ server_service_proto->set_server_thread_num(12);
+
+ for (auto& tuple : this->table_id_map) {
+ VLOG(0) << " make a new table " << tuple.second;
+ ::paddle::distributed::TableParameter* sparse_table_proto =
+ downpour_server_proto->add_downpour_table_param();
+    std::vector<std::string> feat_name;
+    std::vector<std::string> feat_dtype;
+    std::vector<int32_t> feat_shape;
+ for (size_t i = 0; i < this->table_feat_conf_table_name.size(); i++) {
+ if (tuple.first == table_feat_conf_table_name[i]) {
+ feat_name.push_back(table_feat_conf_feat_name[i]);
+ feat_dtype.push_back(table_feat_conf_feat_dtype[i]);
+ feat_shape.push_back(table_feat_conf_feat_shape[i]);
+ }
+ }
+ std::string table_type;
+ if (tuple.second < this->num_node_types) {
+ table_type = "node";
+ } else {
+ table_type = "edge";
+ }
+
+ GetDownpourSparseTableProto(sparse_table_proto, tuple.second, tuple.first,
+ table_type, feat_name, feat_dtype, feat_shape);
+ }
+
+ return worker_fleet_desc;
+}
+void GraphPyClient::load_edge_file(std::string name, std::string filepath,
+ bool reverse) {
+ // 'e' means load edge
+ std::string params = "e";
+ if (reverse) {
+ // 'e<' means load edges from $2 to $1
+ params += "<";
+ } else {
+ // 'e>' means load edges from $1 to $2
+ params += ">";
+ }
+ if (this->table_id_map.count(name)) {
+ VLOG(0) << "loadding data with type " << name << " from " << filepath;
+ uint32_t table_id = this->table_id_map[name];
+ auto status =
+ get_ps_client()->load(table_id, std::string(filepath), params);
+ status.wait();
+ }
+}
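+// Illustrative call (table name and path are example values only): load "u2u"
+// edges with reverse=true, i.e. params "e<", storing each line's edge from
+// the second column to the first:
+//
+//   client.load_edge_file("u2u", "/tmp/u2u_edges.txt", true);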
+
+void GraphPyClient::load_node_file(std::string name, std::string filepath) {
+ // 'n' means load nodes and 'node_type' follows
+ std::string params = "n" + name;
+ if (this->table_id_map.count(name)) {
+ uint32_t table_id = this->table_id_map[name];
+ auto status =
+ get_ps_client()->load(table_id, std::string(filepath), params);
+ status.wait();
+ }
+}
+std::vector<std::vector<std::pair<uint64_t, float>>>
+GraphPyClient::batch_sample_neighboors(std::string name,
+                                       std::vector<uint64_t> node_ids,
+                                       int sample_size) {
+  std::vector<std::vector<std::pair<uint64_t, float>>> v;
+ if (this->table_id_map.count(name)) {
+ uint32_t table_id = this->table_id_map[name];
+ auto status =
+ worker_ptr->batch_sample_neighboors(table_id, node_ids, sample_size, v);
+ status.wait();
+ }
+ return v;
+}
+
+std::vector<uint64_t> GraphPyClient::random_sample_nodes(std::string name,
+                                                         int server_index,
+                                                         int sample_size) {
+  std::vector<uint64_t> v;
+ if (this->table_id_map.count(name)) {
+ uint32_t table_id = this->table_id_map[name];
+ auto status =
+ worker_ptr->random_sample_nodes(table_id, server_index, sample_size, v);
+ status.wait();
+ }
+ return v;
+}
+
+// returns one raw byte string per (feature_name, node_id) pair
+std::vector<std::vector<std::string>> GraphPyClient::get_node_feat(
+    std::string node_type, std::vector<uint64_t> node_ids,
+    std::vector<std::string> feature_names) {
+  std::vector<std::vector<std::string>> v(
+      feature_names.size(), std::vector<std::string>(node_ids.size()));
+ if (this->table_id_map.count(node_type)) {
+ uint32_t table_id = this->table_id_map[node_type];
+ auto status =
+ worker_ptr->get_node_feat(table_id, node_ids, feature_names, v);
+ status.wait();
+ }
+ return v;
+}
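+// Indexing of the returned matrix: v[feat_idx][node_idx] is the raw byte
+// string of feature_names[feat_idx] on node_ids[node_idx], matching the
+// feature-major, node-minor order the server streams in
+// GraphBrpcService::graph_get_node_feat.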
+
+std::vector<FeatureNode> GraphPyClient::pull_graph_list(std::string name,
+                                                        int server_index,
+                                                        int start, int size,
+                                                        int step) {
+  std::vector<FeatureNode> res;
+ if (this->table_id_map.count(name)) {
+ uint32_t table_id = this->table_id_map[name];
+ auto status = worker_ptr->pull_graph_list(table_id, server_index, start,
+ size, step, res);
+ status.wait();
+ }
+ return res;
+}
+
+void GraphPyClient::stop_server() {
+ VLOG(0) << "going to stop server";
+  std::unique_lock<std::mutex> lock(mutex_);
+ if (stoped_) return;
+ auto status = this->worker_ptr->stop_server();
+ if (status.get() == 0) stoped_ = true;
+}
+void GraphPyClient::finalize_worker() { this->worker_ptr->finalize_worker(); }
+}  // namespace distributed
+}  // namespace paddle
diff --git a/paddle/fluid/distributed/service/graph_py_service.h b/paddle/fluid/distributed/service/graph_py_service.h
new file mode 100644
index 0000000000000000000000000000000000000000..c6657be96ba446d2f7538943aab43dd47e1868fb
--- /dev/null
+++ b/paddle/fluid/distributed/service/graph_py_service.h
@@ -0,0 +1,166 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <chrono>
+#include <condition_variable>  // NOLINT
+#include <iostream>
+#include <memory>
+#include <mutex>
+#include <sstream>
+#include <string>
+#include <thread>  // NOLINT
+#include <unordered_map>
+#include <vector>
+#include "google/protobuf/text_format.h"
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/framework/variable.h"
+
+#include "paddle/fluid/distributed/ps.pb.h"
+#include "paddle/fluid/distributed/service/env.h"
+#include "paddle/fluid/distributed/service/graph_brpc_client.h"
+#include "paddle/fluid/distributed/service/graph_brpc_server.h"
+#include "paddle/fluid/distributed/service/sendrecv.pb.h"
+#include "paddle/fluid/distributed/service/service.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/place.h"
+#include "paddle/fluid/string/printf.h"
+namespace paddle {
+namespace distributed {
+class GraphPyService {
+ protected:
+  std::vector<std::string> server_list, port_list, host_sign_list;
+  int server_size, shard_num;
+  int num_node_types;
+  std::unordered_map<std::string, uint32_t> table_id_map;
+  std::vector<std::string> table_feat_conf_table_name;
+  std::vector<std::string> table_feat_conf_feat_name;
+  std::vector<std::string> table_feat_conf_feat_dtype;
+  std::vector<int32_t> table_feat_conf_feat_shape;
+
+ public:
+ int get_shard_num() { return shard_num; }
+ void set_shard_num(int shard_num) { this->shard_num = shard_num; }
+  void GetDownpourSparseTableProto(
+      ::paddle::distributed::TableParameter* sparse_table_proto,
+      uint32_t table_id, std::string table_name, std::string table_type,
+      std::vector<std::string> feat_name, std::vector<std::string> feat_dtype,
+      std::vector<int32_t> feat_shape) {
+ sparse_table_proto->set_table_id(table_id);
+ sparse_table_proto->set_table_class("GraphTable");
+ sparse_table_proto->set_shard_num(shard_num);
+ sparse_table_proto->set_type(::paddle::distributed::PS_SPARSE_TABLE);
+ ::paddle::distributed::TableAccessorParameter* accessor_proto =
+ sparse_table_proto->mutable_accessor();
+
+ ::paddle::distributed::CommonAccessorParameter* common_proto =
+ sparse_table_proto->mutable_common();
+
+ // Set GraphTable Parameter
+ common_proto->set_table_name(table_name);
+ common_proto->set_name(table_type);
+ for (size_t i = 0; i < feat_name.size(); i++) {
+ common_proto->add_params(feat_dtype[i]);
+ common_proto->add_dims(feat_shape[i]);
+ common_proto->add_attributes(feat_name[i]);
+ }
+
+ accessor_proto->set_accessor_class("CommMergeAccessor");
+ }
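+  // Mapping used above when declaring a graph table: each feature column is
+  // spread across the CommonAccessorParameter's repeated fields, with
+  // params <- dtype, dims <- shape, and attributes <- name; table_name keeps
+  // the user-visible name and name carries the "node"/"edge" table type.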
+
+ void set_server_size(int server_size) { this->server_size = server_size; }
+ void set_num_node_types(int num_node_types) {
+ this->num_node_types = num_node_types;
+ }
+  int get_server_size() { return server_size; }
+  std::vector<std::string> split(std::string& str, const char pattern);
+  void set_up(std::string ips_str, int shard_num,
+              std::vector<std::string> node_types,