提交 2912d531 编写于 作者: H heqiaozhi

fix code style bug & change pslib.cmake & change Cmakelist adapt pslib

上级 c59cdf3a
......@@ -65,6 +65,7 @@ option(WITH_GOLANG "Compile PaddlePaddle with GOLANG" OFF)
option(GLIDE_INSTALL "Download and install go dependencies " ON)
option(USE_NNPACK "Compile PaddlePaddle with NNPACK library" OFF)
option(WITH_DISTRIBUTE "Compile with distributed support" OFF)
option(WITH_PSLIB "Compile with pslib support" OFF)
option(USE_EIGEN_FOR_BLAS "Use matrix multiplication in Eigen" OFF)
option(EIGEN_USE_THREADS "Compile with multi-threaded Eigen" OFF)
option(WITH_ARM_FP16 "Use half precision support on armv8.2-a cpu" OFF)
......@@ -216,9 +217,12 @@ include(external/warpctc) # download, build, install warpctc
include(cupti)
include(external/gzstream)
endif (NOT WIN32)
include(external/libmct)
include(external/pslib_brpc)
include(external/pslib)
if(WITH_PSLIB)
include(external/libmct)
include(external/pslib_brpc)
include(external/pslib)
endif()
if(WITH_DISTRIBUTE)
if(WITH_GRPC)
......@@ -279,11 +283,14 @@ set(EXTERNAL_LIBS
protobuf
zlib
${PYTHON_LIBRARIES}
pslib
pslib_brpc
libmct
)
if(WITH_PSLIB)
list(APPEND EXTERNAL_LIBS pslib)
list(APPEND EXTERNAL_LIBS pslib_brpc)
list(APPEND EXTERNAL_LIBS libmct)
endif(WITH_PSLIB)
if(WITH_AMD_GPU)
find_package(HIP)
include(hip)
......
......@@ -30,9 +30,10 @@ SET(PSLIB_PROJECT "extern_pslib")
IF((NOT DEFINED PSLIB_VER) OR (NOT DEFINED PSLIB_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_VER "0.1.0" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/pslib.tar.gz" CACHE STRING "" FORCE)
SET(PSLIB_NAME "pslib" CACHE STRING "" FORCE)
SET(PSLIB_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_VER}/${PSLIB_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "PSLIB_VER: ${PSLIB_VER}, PSLIB_URL: ${PSLIB_URL}")
MESSAGE(STATUS "PSLIB_NAME: ${PSLIB_NAME}, PSLIB_URL: ${PSLIB_URL}")
SET(PSLIB_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib")
SET(PSLIB_DOWNLOAD_DIR "${PSLIB_SOURCE_DIR}/src/${PSLIB_PROJECT}")
SET(PSLIB_DST_DIR "pslib")
......@@ -50,7 +51,7 @@ INCLUDE_DIRECTORIES(${PSLIB_INC_DIR})
FILE(WRITE ${PSLIB_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_VER}/include ${PSLIB_VER}/lib \n"
"install(DIRECTORY ${PSLIB_NAME}/include ${PSLIB_NAME}/lib \n"
" DESTINATION ${PSLIB_DST_DIR})\n")
ExternalProject_Add(
......@@ -58,8 +59,8 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_VER}.tar.gz
&& tar zxvf ${PSLIB_VER}.tar.gz
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
&& tar zxvf ${PSLIB_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT}
......
......@@ -50,7 +50,6 @@ void AsyncExecutor::CreateThreads(
worker->BindingDataFeedMemory();
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->BindingSlotVariableMemory();
worker->SetParamConfig(&_param_config);
}
......@@ -79,7 +78,7 @@ void AsyncExecutor::InitWorker(const std::string& dist_desc,
_pslib_ptr = std::shared_ptr<paddle::distributed::PSlib>(
new paddle::distributed::PSlib());
_pslib_ptr->init_worker(
dist_desc, host_sign_list.data(), node_num, index);
dist_desc, (uint64_t*)(host_sign_list.data()), node_num, index);
InitParamConfig();
}
......@@ -93,8 +92,8 @@ void AsyncExecutor::StopServer() {
}
void AsyncExecutor::GatherServers(
std::vector<uint64_t>& host_sign_list, int node_num) {
_pslib_ptr->gather_servers(host_sign_list.data(), node_num);
const std::vector<uint64_t>& host_sign_list, int node_num) {
_pslib_ptr->gather_servers((uint64_t*)(host_sign_list.data()), node_num);
}
void AsyncExecutor::InitParamConfig() {
......
......@@ -43,9 +43,9 @@ inline std::default_random_engine& local_random_engine() {
struct engine_wrapper_t {
std::default_random_engine engine;
engine_wrapper_t() {
static std::atomic<uint64> x(0);
static std::atomic<uint64_t> x(0);
std::seed_seq sseq = {x++, x++, x++,
static_cast<uint64>(current_realtime() * 1000)};
static_cast<uint64_t>(current_realtime() * 1000)};
engine.seed(sseq);
}
};
......
......@@ -68,6 +68,7 @@ bool DataFeed::PickOneFile(std::string* filename) {
return false;
}
*filename = filelist_[file_idx_++];
LOG(ERROR) << "pick file:" << *filename;
return true;
}
......
......@@ -637,7 +637,7 @@ void AsyncExecutorThreadWorker::collect_feasign_info(
}
void AsyncExecutorThreadWorker::check_pull_push_memory(
std::vector<uint64_t>& features,
const std::vector<uint64_t>& features,
std::vector<std::vector<float>>& push_g,
int dim) {
push_g.resize(features.size() + 1);
......@@ -647,7 +647,7 @@ void AsyncExecutorThreadWorker::check_pull_push_memory(
}
void AsyncExecutorThreadWorker::check_pull_push_memory(
std::vector<uint64_t>& features,
const std::vector<uint64_t>& features,
std::vector<float*>& push_g,
int dim) {
if (features.size() > push_g.size()) {
......
......@@ -155,7 +155,7 @@ class ExecutorThreadWorker {
// set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
virtual void SetPSlibPtr(
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr);
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {};
virtual void SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {}
virtual void SetParamConfig(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册