提交 2912d531 编写于 作者: H heqiaozhi

fix code style bug & change pslib.cmake & change Cmakelist adapt pslib

上级 c59cdf3a
...@@ -65,6 +65,7 @@ option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF) ...@@ -65,6 +65,7 @@ option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON) option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF) option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(WITH_DISTRIBUTE "Compile with distributed support" OFF) option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF) option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF) option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF) option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
...@@ -216,9 +217,12 @@ include(external/warpctc) # download, build, install warpctc ...@@ -216,9 +217,12 @@ include(external/warpctc) # download, build, install warpctc
include(cupti) include(cupti)
include(external/gzstream) include(external/gzstream)
endif (NOT WIN32) endif (NOT WIN32)
include(external/libmct)
include(external/pslib_brpc) if(WITH_PSLIB)
include(external/pslib) include(external/libmct)
include(external/pslib_brpc)
include(external/pslib)
endif()
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
if(WITH_GRPC) if(WITH_GRPC)
...@@ -279,11 +283,14 @@ set(EXTERNAL_LIBS ...@@ -279,11 +283,14 @@ set(EXTERNAL_LIBS
protobuf protobuf
zlib zlib
${PYTHON_LIBRARIES} ${PYTHON_LIBRARIES}
pslib
pslib_brpc
libmct
) )
if(WITH_PSLIB)
list(APPEND EXTERNAL_LIBS pslib)
list(APPEND EXTERNAL_LIBS pslib_brpc)
list(APPEND EXTERNAL_LIBS libmct)
endif(WITH_PSLIB)
if(WITH_AMD_GPU) if(WITH_AMD_GPU)
find_package(HIP) find_package(HIP)
include(hip) include(hip)
......
...@@ -30,9 +30,10 @@ SET(PSLIB_PROJECT "extern_pslib") ...@@ -30,9 +30,10 @@ SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL)) IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
MESSAGE(STATUS "use pre defined download url") MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE) SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/pslib.tar.gz" CACHE STRING "" FORCE) SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF() ENDIF()
MESSAGE(STATUS "PSLIB_VER: ${PSLIB_VER}, PSLIB_URL: ${PSLIB_URL}") MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib") SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
SET(PSLIB_DOWNLOAD_DIR "${PSLIB_SOURCE_DIR}/src/${PSLIB_PROJECT}") SET(PSLIB_DOWNLOAD_DIR "${PSLIB_SOURCE_DIR}/src/${PSLIB_PROJECT}")
SET(PSLIB_DST_DIR "pslib") SET(PSLIB_DST_DIR "pslib")
...@@ -50,7 +51,7 @@ INCLUDE_DIRECTORIES(${PSLIB_INC_DIR}) ...@@ -50,7 +51,7 @@ INCLUDE_DIRECTORIES(${PSLIB_INC_DIR})
FILE(WRITE ${PSLIB_DOWNLOAD_DIR}/CMakeLists.txt FILE(WRITE ${PSLIB_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB)\n" "PROJECT(PSLIB)\n"
"cmake_minimum_required(VERSION 3.0)\n" "cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_VER}/include ${PSLIB_VER}/lib \n" "install(DIRECTORY ${PSLIB_NAME}/include ${PSLIB_NAME}/lib \n"
" DESTINATION ${PSLIB_DST_DIR})\n") " DESTINATION ${PSLIB_DST_DIR})\n")
ExternalProject_Add( ExternalProject_Add(
...@@ -58,8 +59,8 @@ ExternalProject_Add( ...@@ -58,8 +59,8 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_SOURCE_DIR} PREFIX ${PSLIB_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR} DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_VER}.tar.gz DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
&& tar zxvf ${PSLIB_VER}.tar.gz && tar zxvf ${PSLIB_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1 DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT} CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT}
......
...@@ -50,7 +50,6 @@ void AsyncExecutor::CreateThreads( ...@@ -50,7 +50,6 @@ void AsyncExecutor::CreateThreads(
worker->BindingDataFeedMemory(); worker->BindingDataFeedMemory();
worker->SetPSlibPtr(_pslib_ptr); worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread); worker->SetPullDenseThread(_pull_dense_thread);
worker->BindingSlotVariableMemory();
worker->SetParamConfig(&_param_config); worker->SetParamConfig(&_param_config);
} }
...@@ -79,7 +78,7 @@ void AsyncExecutor::InitWorker(const std::string& dist_desc, ...@@ -79,7 +78,7 @@ void AsyncExecutor::InitWorker(const std::string& dist_desc,
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>( _pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib()); new paddle::distributed::PSlib());
_pslib_ptr->init_worker( _pslib_ptr->init_worker(
dist_desc, host_sign_list.data(), node_num, index); dist_desc, (uint64_t*)(host_sign_list.data()), node_num, index);
InitParamConfig(); InitParamConfig();
} }
...@@ -93,8 +92,8 @@ void AsyncExecutor::StopServer() { ...@@ -93,8 +92,8 @@ void AsyncExecutor::StopServer() {
} }
void AsyncExecutor::GatherServers( void AsyncExecutor::GatherServers(
std::vector<uint64_t>& host_sign_list, int node_num) { const std::vector<uint64_t>& host_sign_list, int node_num) {
_pslib_ptr->gather_servers(host_sign_list.data(), node_num); _pslib_ptr->gather_servers((uint64_t*)(host_sign_list.data()), node_num);
} }
void AsyncExecutor::InitParamConfig() { void AsyncExecutor::InitParamConfig() {
......
...@@ -43,9 +43,9 @@ inline std::default_random_engine& local_random_engine() { ...@@ -43,9 +43,9 @@ inline std::default_random_engine& local_random_engine() {
struct engine_wrapper_t { struct engine_wrapper_t {
std::default_random_engine engine; std::default_random_engine engine;
engine_wrapper_t() { engine_wrapper_t() {
static std::atomic<uint64> x(0); static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++, std::seed_seq sseq = {x++, x++, x++,
static_cast<uint64>(current_realtime() * 1000)}; static_cast<uint64_t>(current_realtime() * 1000)};
engine.seed(sseq); engine.seed(sseq);
} }
}; };
......
...@@ -68,6 +68,7 @@ bool DataFeed::PickOneFile(std::string* filename) { ...@@ -68,6 +68,7 @@ bool DataFeed::PickOneFile(std::string* filename) {
return false; return false;
} }
*filename = filelist_[file_idx_++]; *filename = filelist_[file_idx_++];
LOG(ERROR) << "pick file:" << *filename;
return true; return true;
} }
......
...@@ -637,7 +637,7 @@ void AsyncExecutorThreadWorker::collect_feasign_info( ...@@ -637,7 +637,7 @@ void AsyncExecutorThreadWorker::collect_feasign_info(
} }
void AsyncExecutorThreadWorker::check_pull_push_memory( void AsyncExecutorThreadWorker::check_pull_push_memory(
std::vector<uint64_t>& features, const std::vector<uint64_t>& features,
std::vector<std::vector<float>>& push_g, std::vector<std::vector<float>>& push_g,
int dim) { int dim) {
push_g.resize(features.size() + 1); push_g.resize(features.size() + 1);
...@@ -647,7 +647,7 @@ void AsyncExecutorThreadWorker::check_pull_push_memory( ...@@ -647,7 +647,7 @@ void AsyncExecutorThreadWorker::check_pull_push_memory(
} }
void AsyncExecutorThreadWorker::check_pull_push_memory( void AsyncExecutorThreadWorker::check_pull_push_memory(
std::vector<uint64_t>& features, const std::vector<uint64_t>& features,
std::vector<float*>& push_g, std::vector<float*>& push_g,
int dim) { int dim) {
if (features.size() > push_g.size()) { if (features.size() > push_g.size()) {
......
...@@ -155,7 +155,7 @@ class ExecutorThreadWorker { ...@@ -155,7 +155,7 @@ class ExecutorThreadWorker {
// set fetch variable names from python interface assigned by users // set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names); void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
virtual void SetPSlibPtr( virtual void SetPSlibPtr(
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr); std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {};
virtual void SetPullDenseThread( virtual void SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {} std::shared_ptr<DensePullThread> dpt) {}
virtual void SetParamConfig( virtual void SetParamConfig(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册