提交 f81957a7 编写于 作者: H heqiaozhi

refine cmake for pslib & pre_define

上级 2912d531
......@@ -222,7 +222,7 @@ if(WITH_PSLIB)
include(external/libmct)
include(external/pslib_brpc)
include(external/pslib)
endif()
endif(WITH_PSLIB)
if(WITH_DISTRIBUTE)
if(WITH_GRPC)
......
......@@ -84,6 +84,10 @@ if(NOT WITH_GOLANG)
add_definitions(-DPADDLE_WITHOUT_GOLANG)
endif(NOT WITH_GOLANG)
if(WITH_PSLIB)
add_definitions(-DPADDLE_WITH_PSLIB)
endif()
if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA)
......
......@@ -29,10 +29,11 @@ INCLUDE(ExternalProject)
SET(LIBMCT_PROJECT "extern_libmct")
IF((NOT DEFINED LIBMCT_VER) OR (NOT DEFINED LIBMCT_URL))
MESSAGE(STATUS "use pre defined download url")
SET(LIBMCT_VER "libmct" CACHE STRING "" FORCE) #todo libmct version
SET(LIBMCT_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${LIBMCT_VER}.tar.gz" CACHE STRING "" FORCE) #todo libmct url
SET(LIBMCT_VER "0.1.0" CACHE STRING "" FORCE)
SET(LIBMCT_NAME "libmct" CACHE STRING "" FORCE)
SET(LIBMCT_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${LIBMCT_VER}/${LIBMCT_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "LIBMCT_VER: ${LIBMCT_VER}, LIBMCT_URL: ${LIBMCT_URL}")
MESSAGE(STATUS "LIBMCT_NAME: ${LIBMCT_NAME}, LIBMCT_URL: ${LIBMCT_URL}")
SET(LIBMCT_SOURCE_DIR "${THIRD_PARTY_PATH}/libmct")
SET(LIBMCT_DOWNLOAD_DIR "${LIBMCT_SOURCE_DIR}/src/${LIBMCT_PROJECT}")
SET(LIBMCT_DST_DIR "libmct")
......@@ -47,7 +48,7 @@ INCLUDE_DIRECTORIES(${LIBMCT_INC_DIR})
FILE(WRITE ${LIBMCT_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(LIBMCT)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${LIBMCT_VER}/include ${LIBMCT_VER}/lib \n"
"install(DIRECTORY ${LIBMCT_NAME}/include ${LIBMCT_NAME}/lib \n"
" DESTINATION ${LIBMCT_DST_DIR})\n")
ExternalProject_Add(
......@@ -55,8 +56,8 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${LIBMCT_SOURCE_DIR}
DOWNLOAD_DIR ${LIBMCT_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${LIBMCT_URL} -c -q -O ${LIBMCT_VER}.tar.gz
&& tar zxvf ${LIBMCT_VER}.tar.gz
DOWNLOAD_COMMAND wget --no-check-certificate ${LIBMCT_URL} -c -q -O ${LIBMCT_NAME}.tar.gz
&& tar zxvf ${LIBMCT_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${LIBMCT_INSTALL_ROOT}
......
......@@ -27,12 +27,13 @@ ENDIF()
INCLUDE(ExternalProject)
SET(PSLIB_BRPC_PROJECT "extern_pslib_brpc")
IF((NOT DEFINED PSLIB_BRPC_VER) OR (NOT DEFINED PSLIB_BRPC_URL))
IF((NOT DEFINED PSLIB_BRPC_NAME) OR (NOT DEFINED PSLIB_BRPC_URL))
MESSAGE(STATUS "use pre defined download url")
SET(PSLIB_BRPC_VER "pslib_brpc" CACHE STRING "" FORCE) #todo pslib version
SET(PSLIB_BRPC_URL "http://bjyz-heqiaozhi-dev-new.epc.baidu.com:8000/${PSLIB_BRPC_VER}.tar.gz" CACHE STRING "" FORCE) #todo pslib_brpc url
SET(PSLIB_BRPC_VER "0.1.0" CACHE STRING "" FORCE)
SET(PSLIB_BRPC_NAME "pslib_brpc" CACHE STRING "" FORCE)
SET(PSLIB_BRPC_URL "https://raw.githubusercontent.com/PaddlePaddle/Fleet/release/${PSLIB_BRPC_VER}/${PSLIB_BRPC_NAME}.tar.gz" CACHE STRING "" FORCE)
ENDIF()
MESSAGE(STATUS "PSLIB_BRPC_VER: ${PSLIB_BRPC_VER}, PSLIB_BRPC_URL: ${PSLIB_BRPC_URL}")
MESSAGE(STATUS "PSLIB_BRPC_NAME: ${PSLIB_BRPC_NAME}, PSLIB_BRPC_URL: ${PSLIB_BRPC_URL}")
SET(PSLIB_BRPC_SOURCE_DIR "${THIRD_PARTY_PATH}/pslib_brpc")
SET(PSLIB_BRPC_DOWNLOAD_DIR "${PSLIB_BRPC_SOURCE_DIR}/src/${PSLIB_BRPC_PROJECT}")
SET(PSLIB_BRPC_DST_DIR "pslib_brpc")
......@@ -50,7 +51,7 @@ INCLUDE_DIRECTORIES(${PSLIB_BRPC_INC_DIR})
FILE(WRITE ${PSLIB_BRPC_DOWNLOAD_DIR}/CMakeLists.txt
"PROJECT(PSLIB_BRPC)\n"
"cmake_minimum_required(VERSION 3.0)\n"
"install(DIRECTORY ${PSLIB_BRPC_VER}/include ${PSLIB_BRPC_VER}/lib \n"
"install(DIRECTORY ${PSLIB_BRPC_NAME}/include ${PSLIB_BRPC_NAME}/lib \n"
" DESTINATION ${PSLIB_BRPC_DST_DIR})\n")
ExternalProject_Add(
......@@ -58,8 +59,8 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_BRPC_SOURCE_DIR}
DOWNLOAD_DIR ${PSLIB_BRPC_DOWNLOAD_DIR}
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_VER}.tar.gz
&& tar zxvf ${PSLIB_BRPC_VER}.tar.gz
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_NAME}.tar.gz
&& tar zxvf ${PSLIB_BRPC_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_BRPC_INSTALL_ROOT}
......
......@@ -180,7 +180,12 @@ cc_library(parallel_executor SRCS parallel_executor.cc DEPS
graph build_strategy
fast_threaded_ssa_graph_executor variable_helper)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib)
if(WITH_PSLIB)
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper pslib_brpc pslib)
else()
cc_library(async_executor SRCS async_executor.cc data_feed.cc data_feed_factory.cc executor_thread_worker.cc DEPS op_registry device_context scope framework_proto glog lod_rank_table feed_fetch_method graph_to_program_pass async_executor_proto variable_helper)
endif(WITH_PSLIB)
cc_test(data_feed_test SRCS data_feed_test.cc DEPS async_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto)
......
......@@ -29,7 +29,9 @@ limitations under the License. */
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
#ifdef PADDLE_WITH_PSLIB
#include "pslib.h"
#endif
namespace paddle {
namespace framework {
......@@ -48,9 +50,11 @@ void AsyncExecutor::CreateThreads(
worker->SetDataFeed(reader);
worker->SetFetchVarNames(fetch_var_names);
worker->BindingDataFeedMemory();
#ifdef PADDLE_WITH_PSLIB
worker->SetPSlibPtr(_pslib_ptr);
worker->SetPullDenseThread(_pull_dense_thread);
worker->SetParamConfig(&_param_config);
#endif
}
void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
......@@ -64,6 +68,7 @@ void PrepareReaders(std::vector<std::shared_ptr<DataFeed>>& readers, // NOLINT
readers[0]->SetFileList(filelist);
}
#ifdef PADDLE_WITH_PSLIB
void AsyncExecutor::InitServer(const std::string& dist_desc, int index) {
_pslib_ptr =
std::shared_ptr<paddle::distributed::PSlib>(
......@@ -231,6 +236,7 @@ void AsyncExecutor::PrepareDenseThread(const std::string& mode) {
_pull_dense_thread->start();
}
}
#endif
void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
const std::string& data_feed_desc_str,
......@@ -279,15 +285,21 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
// todo: should be factory method for creating datafeed
std::vector<std::shared_ptr<DataFeed>> readers;
PrepareReaders(readers, actual_thread_num, data_feed_desc, filelist);
#ifdef PADDLE_WITH_PSLIB
PrepareDenseThread(mode);
#endif
std::vector<std::shared_ptr<ExecutorThreadWorker>> workers;
workers.resize(actual_thread_num);
for (auto& worker : workers) {
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
worker.reset(new AsyncExecutorThreadWorker);
} else {
worker.reset(new ExecutorThreadWorker);
}
#else
worker.reset(new ExecutorThreadWorker);
#endif
}
// prepare thread resource here
......@@ -306,9 +318,11 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
for (auto& th : threads) {
th.join();
}
#ifdef PADDLE_WITH_PSLIB
if (mode == "mpi") {
_pull_dense_thread->stop();
}
#endif
root_scope_->DropKids();
return;
......
......@@ -64,6 +64,7 @@ class AsyncExecutor {
const std::vector<std::string>& fetch_names,
const std::string& mode,
const bool debug = false);
#ifdef PADDLE_WITH_PSLIB
void InitServer(const std::string& dist_desc, int index);
void InitWorker(
const std::string& dist_desc,
......@@ -75,7 +76,7 @@ class AsyncExecutor {
void InitModel();
void SaveModel(const std::string& path);
void InitParamConfig();
#endif
private:
void CreateThreads(ExecutorThreadWorker* worker,
const ProgramDesc& main_program,
......@@ -83,16 +84,18 @@ class AsyncExecutor {
const std::vector<std::string>& fetch_var_names,
Scope* root_scope, const int thread_index,
const bool debug);
#ifdef PADDLE_WITH_PSLIB
void PrepareDenseThread(const std::string& mode);
#endif
public:
#ifdef PADDLE_WITH_PSLIB
std::shared_ptr<paddle::distributed::PSlib> _pslib_ptr;
std::shared_ptr<DensePullThread> _pull_dense_thread;
AsyncWorkerParamConfig _param_config;
#endif
Scope* root_scope_;
platform::Place place_;
AsyncWorkerParamConfig _param_config;
private:
int actual_thread_num;
......
......@@ -31,6 +31,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
#ifdef PADDLE_WITH_PSLIB
int DensePullThread::start() {
_running = true;
_t = std::thread(&DensePullThread::run, this);
......@@ -112,6 +113,7 @@ void DensePullThread::increase_thread_version(
std::lock_guard<std::mutex> lock(_mutex_for_version);
_training_versions[table_id][thread_id]++;
}
#endif
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
auto& block = program.Block(0);
......@@ -302,6 +304,7 @@ void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
root_scope_ = g_scope;
}
#ifdef PADDLE_WITH_PSLIB
// AsyncExecutor
void AsyncExecutorThreadWorker::TrainFiles() {
SetDevice();
......@@ -659,6 +662,7 @@ void AsyncExecutorThreadWorker::check_pull_push_memory(
}
}
}
#endif
} // einit_modelnd namespace framework
} // end namespace paddle
......@@ -25,14 +25,16 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_PSLIB
#include "pslib.h"
#endif
namespace paddle {
namespace framework {
const static uint32_t MAX_FEASIGN_NUM = 1000 * 100 * 100;
void CreateTensor(Variable* var, proto::VarType::Type var_type);
#ifdef PADDLE_WITH_PSLIB
const static uint32_t MAX_FEASIGN_NUM = 1000 * 100 * 100;
struct AsyncWorkerParamConfig {
int slot_dim;
......@@ -130,6 +132,8 @@ class DensePullThread {
float _total_batch_num = 0;
};
#endif
class ExecutorThreadWorker {
public:
ExecutorThreadWorker()
......@@ -154,12 +158,14 @@ class ExecutorThreadWorker {
virtual void TrainFiles();
// set fetch variable names from python interface assigned by users
void SetFetchVarNames(const std::vector<std::string>& fetch_var_names);
#ifdef PADDLE_WITH_PSLIB
virtual void SetPSlibPtr(
std::shared_ptr<paddle::distributed::PSlib> pslib_ptr) {};
virtual void SetPullDenseThread(
std::shared_ptr<DensePullThread> dpt) {}
virtual void SetParamConfig(
AsyncWorkerParamConfig * param_config) {}
#endif
private:
void CreateThreadScope(const framework::ProgramDesc& program);
......@@ -188,6 +194,7 @@ class ExecutorThreadWorker {
bool debug_;
};
#ifdef PADDLE_WITH_PSLIB
class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
public:
AsyncExecutorThreadWorker() {}
......@@ -238,6 +245,7 @@ class AsyncExecutorThreadWorker: public ExecutorThreadWorker {
AsyncWorkerParamConfig* _param_config;
};
#endif
} // namespace framework
} // namespace paddle
......@@ -41,6 +41,7 @@ namespace pd = paddle::framework;
namespace paddle {
namespace pybind {
using set_name_func = void (pd::DataFeedDesc::*)(const std::string&);
#ifdef PADDLE_WITH_PSLIB
void BindAsyncExecutor(py::module* m) {
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init([](framework::Scope* scope, const platform::Place& place) {
......@@ -56,5 +57,15 @@ void BindAsyncExecutor(py::module* m) {
.def("init_model", &framework::AsyncExecutor::InitModel)
.def("save_model", &framework::AsyncExecutor::SaveModel);
} // end BindAsyncExecutor
#else
void BindAsyncExecutor(py::module* m) {
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init([](framework::Scope* scope, const platform::Place& place) {
return std::unique_ptr<framework::AsyncExecutor>(
new framework::AsyncExecutor(scope, place));
}))
.def("run_from_files", &framework::AsyncExecutor::RunFromFile)
} // end BindAsyncExecutor
#endif
} // end namespace pybind
} // end namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册