diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d7f7b60a002e23702d5d11a819641c4e0c205f5..81a97265a358e09bd65d1c6b62b1ac727601946e 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -160,6 +160,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF) option(WITH_XBYAK "Compile with xbyak support" ON) option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) +option(WITH_PSCORE "Compile with parameter server support" ${WITH_DISTRIBUTE}) option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE}) diff --git a/cmake/configure.cmake b/cmake/configure.cmake index df5c204eaec5cdaadec22f32edfb3d926531c86e..aeec7da2e6f0250276a607af321183983812d09c 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -160,6 +160,11 @@ if(WITH_DISTRIBUTE) add_definitions(-DPADDLE_WITH_DISTRIBUTE) endif() +if(WITH_PSCORE) + add_definitions(-DPADDLE_WITH_PSCORE) +endif() + + if(WITH_GRPC) add_definitions(-DPADDLE_WITH_GRPC) endif(WITH_GRPC) diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 4ad2f84d33240558096bb43b274e988b9b2be210..84020f57f13e8a38d5716407c6403e874abc146e 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -274,7 +274,7 @@ if(WITH_BOX_PS) list(APPEND third_party_deps extern_box_ps) endif(WITH_BOX_PS) -if (WITH_DISTRIBUTE) +if (WITH_PSCORE) include(external/snappy) list(APPEND third_party_deps extern_snappy) diff --git a/paddle/fluid/distributed/CMakeLists.txt b/paddle/fluid/distributed/CMakeLists.txt index b9ad4e91ddc863603cebffeb41c7b2150a2b3aa1..5a2d7a06201ba4acff679ffcfee87fde8d025ed6 100644 --- a/paddle/fluid/distributed/CMakeLists.txt +++ b/paddle/fluid/distributed/CMakeLists.txt @@ -1,7 +1,4 @@ -if (WITH_PSLIB) - return() -endif() -if(NOT WITH_DISTRIBUTE) +if(NOT WITH_PSCORE) return() endif() diff --git a/paddle/fluid/distributed/common/registerer.h b/paddle/fluid/distributed/common/registerer.h index a4eab9c4a75e9ecabb183a9f41460a8b0cb516f6..630be930c14d9afc820bd58034ff5ae37751f2fb 100644 --- a/paddle/fluid/distributed/common/registerer.h +++ b/paddle/fluid/distributed/common/registerer.h @@ -69,24 +69,24 @@ class ObjectFactory { }; typedef std::map FactoryMap; -typedef std::map BaseClassMap; +typedef std::map PsCoreClassMap; #ifdef __cplusplus extern "C" { #endif -inline BaseClassMap &global_factory_map() { - static BaseClassMap *base_class = new BaseClassMap(); +inline PsCoreClassMap &global_factory_map() { + static PsCoreClassMap *base_class = new PsCoreClassMap(); return *base_class; } #ifdef __cplusplus } #endif -inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } +inline PsCoreClassMap &global_factory_map_cpp() { return global_factory_map(); } // typedef pa::Any Any; // typedef ::FactoryMap FactoryMap; -#define REGISTER_REGISTERER(base_class) \ +#define REGISTER_PSCORE_REGISTERER(base_class) \ class base_class##Registerer { \ public: \ static base_class *CreateInstanceByName(const ::std::string &name) { \ @@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } } \ }; -#define REGISTER_CLASS(clazz, name) \ +#define REGISTER_PSCORE_CLASS(clazz, name) \ class ObjectFactory##name : public ObjectFactory { \ public: \ Any NewInstance() { return Any(new name()); } \ @@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } } \ void register_factory_##name() __attribute__((constructor)); -#define CREATE_CLASS(base_class, name) \ +#define CREATE_PSCORE_CLASS(base_class, name) \ base_class##Registerer::CreateInstanceByName(name); } // namespace distributed diff --git a/paddle/fluid/distributed/ps.proto b/paddle/fluid/distributed/ps.proto index 88ea04667f7018a19bd786bfffacd101a03084e0..2570d3eaf037013de2ee3ecd47230c7f87c40de1 100644 --- a/paddle/fluid/distributed/ps.proto +++ b/paddle/fluid/distributed/ps.proto @@ -86,7 +86,7 @@ message SparseTableParameter { message ServerServiceParameter { optional string server_class = 1 [ default = "BrpcPsServer" ]; optional string client_class = 2 [ default = "BrpcPsClient" ]; - optional string service_class = 3 [ default = "PsService" ]; + optional string service_class = 3 [ default = "BrpcPsService" ]; optional uint32 start_server_port = 4 [ default = 0 ]; // will find a avaliable port from it optional uint32 server_thread_num = 5 [ default = 12 ]; diff --git a/paddle/fluid/distributed/service/brpc_ps_client.cc b/paddle/fluid/distributed/service/brpc_ps_client.cc index 6f932d55e9a194785bc2e950a75db3c1857d5561..4a07c54375ae1a94ddc104132c414deba997088c 100644 --- a/paddle/fluid/distributed/service/brpc_ps_client.cc +++ b/paddle/fluid/distributed/service/brpc_ps_client.cc @@ -17,8 +17,8 @@ #include #include #include - #include "Eigen/Dense" + #include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/framework/archive.h" @@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, void DownpourPsClientService::service( ::google::protobuf::RpcController *controller, - const ::paddle::PsRequestMessage *request, - ::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) { + const PsRequestMessage *request, PsResponseMessage *response, + ::google::protobuf::Closure *done) { brpc::ClosureGuard done_guard(done); int ret = _client->handle_client2client_msg( request->cmd_id(), request->client_id(), request->data()); diff --git a/paddle/fluid/distributed/service/brpc_ps_client.h b/paddle/fluid/distributed/service/brpc_ps_client.h index 50faf7c9771c58dde24384d179d01b734f38eabf..17a5d53e229dcb3937d93fffb3a0abf1b2678dc1 100644 --- a/paddle/fluid/distributed/service/brpc_ps_client.h +++ b/paddle/fluid/distributed/service/brpc_ps_client.h @@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService { return 0; } virtual void service(::google::protobuf::RpcController *controller, - const ::paddle::PsRequestMessage *request, - ::paddle::PsResponseMessage *response, + const PsRequestMessage *request, + PsResponseMessage *response, ::google::protobuf::Closure *done) override; protected: diff --git a/paddle/fluid/distributed/service/brpc_ps_server.cc b/paddle/fluid/distributed/service/brpc_ps_server.cc index 914b9971cbf948da2403452f0e4b04a0779dbd6b..92a317d4e48d667386d02a5baea00c6992739e0e 100644 --- a/paddle/fluid/distributed/service/brpc_ps_server.cc +++ b/paddle/fluid/distributed/service/brpc_ps_server.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/service/brpc_ps_server.h" + #include // NOLINT #include "Eigen/Dense" #include "butil/endpoint.h" @@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() { LOG(ERROR) << "miss service_class in ServerServiceParameter"; return -1; } - auto *service = CREATE_CLASS(PsBaseService, service_config.service_class()); + auto *service = + CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class()); if (service == NULL) { LOG(ERROR) << "service is unregistered, service_name:" << service_config.service_class(); @@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { int32_t BrpcPsServer::port() { return _server.listen_address().port; } -int32_t PsService::initialize() { +int32_t BrpcPsService::initialize() { _is_initialize_shard_info = false; - _service_handler_map[PS_STOP_SERVER] = &PsService::stop_server; - _service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense; - _service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense; - _service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse; - _service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse; - _service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table; - _service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table; - _service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table; - _service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table; - _service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table; - _service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table; - _service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table; - _service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param; - _service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat; - _service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param; - _service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param; - _service_handler_map[PS_BARRIER] = &PsService::barrier; - _service_handler_map[PS_START_PROFILER] = &PsService::start_profiler; - _service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler; - _service_handler_map[PS_PUSH_GLOBAL_STEP] = &PsService::push_global_step; + _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::stop_server; + _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::pull_dense; + _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::push_dense; + _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::pull_sparse; + _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::push_sparse; + _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::save_one_table; + _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::save_all_table; + _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::shrink_table; + _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::load_one_table; + _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::load_all_table; + _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::clear_one_table; + _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::clear_all_table; + _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::push_dense_param; + _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::print_table_stat; + _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::pull_geo_param; + _service_handler_map[PS_PUSH_SPARSE_PARAM] = + &BrpcPsService::push_sparse_param; + _service_handler_map[PS_BARRIER] = &BrpcPsService::barrier; + _service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler; + _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler; // shard初始化,server启动后才可从env获取到server_list的shard信息 initialize_shard_info(); @@ -116,7 +118,7 @@ int32_t PsService::initialize() { return -1; \ } -int32_t PsService::initialize_shard_info() { +int32_t BrpcPsService::initialize_shard_info() { if (!_is_initialize_shard_info) { std::lock_guard guard(_initialize_shard_mutex); if (_is_initialize_shard_info) { @@ -132,10 +134,10 @@ int32_t PsService::initialize_shard_info() { return 0; } -void PsService::service(google::protobuf::RpcController *cntl_base, - const PsRequestMessage *request, - PsResponseMessage *response, - google::protobuf::Closure *done) { +void BrpcPsService::service(google::protobuf::RpcController *cntl_base, + const PsRequestMessage *request, + PsResponseMessage *response, + google::protobuf::Closure *done) { brpc::ClosureGuard done_guard(done); std::string log_label("ReceiveCmd-"); if (!request->has_table_id()) { @@ -163,9 +165,9 @@ void PsService::service(google::protobuf::RpcController *cntl_base, } } -int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event("PsService->pull_dense"); CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 1) { @@ -191,10 +193,10 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::push_dense_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::push_dense_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event("PsService->push_dense_param"); CHECK_TABLE_EXIST(table, request, response) thread_local std::string push_buffer; @@ -218,9 +220,9 @@ int32_t PsService::push_dense_param(Table *table, return 0; } -int32_t PsService::push_dense(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event("PsService->push_dense"); CHECK_TABLE_EXIST(table, request, response) auto req_buffer_size = request.data().size(); @@ -244,9 +246,9 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::barrier(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 1) { @@ -262,10 +264,10 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::push_sparse_param(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::push_sparse_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event("PsService->push_sparse_param"); CHECK_TABLE_EXIST(table, request, response) auto &push_data = request.data(); @@ -294,9 +296,10 @@ int32_t PsService::push_sparse_param(Table *table, return 0; } -int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::pull_geo_param(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event("PsService->pull_geo_param"); CHECK_TABLE_EXIST(table, request, response) thread_local std::string push_sparse_request_buffer; @@ -316,9 +319,10 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::pull_sparse(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event("PsService->pull_sparse"); CHECK_TABLE_EXIST(table, request, response) thread_local std::string push_sparse_request_buffer; @@ -353,9 +357,10 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::push_sparse(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::RecordEvent record_event("PsService->push_sparse"); CHECK_TABLE_EXIST(table, request, response) auto &push_data = request.data(); @@ -384,10 +389,10 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::print_table_stat(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::print_table_stat(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) std::pair ret = table->print_table_stat(); paddle::framework::BinaryArchive ar; @@ -398,9 +403,10 @@ int32_t PsService::print_table_stat(Table *table, return 0; } -int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::load_one_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -415,9 +421,10 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::load_all_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { auto &table_map = *(_server->table()); for (auto &itr : table_map) { if (load_one_table(itr.second.get(), request, response, cntl) != 0) { @@ -428,9 +435,10 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::save_one_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) if (request.params_size() < 2) { set_response_code( @@ -449,9 +457,10 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request, return feasign_size; } -int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::save_all_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { auto &table_map = *(_server->table()); int32_t all_feasign_size = 0; int32_t feasign_size = 0; @@ -466,9 +475,10 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::shrink_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) table->flush(); if (table->shrink() != 0) { @@ -477,20 +487,20 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::clear_one_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::clear_one_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response) table->flush(); table->clear(); return 0; } -int32_t PsService::clear_all_table(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::clear_all_table(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { auto &table_map = *(_server->table()); for (auto &itr : table_map) { if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { @@ -500,9 +510,10 @@ int32_t PsService::clear_all_table(Table *table, return 0; } -int32_t PsService::stop_server(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::stop_server(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { auto *p_server = _server; std::thread t_stop([p_server]() { p_server->stop(); @@ -512,25 +523,27 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request, return 0; } -int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::stop_profiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::DisableProfiler(platform::EventSortingKey::kDefault, string::Sprintf("server_%s_profile", _rank)); return 0; } -int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::start_profiler(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { platform::EnableProfiler(platform::ProfilerState::kCPU); return 0; } -int32_t PsService::push_global_step(Table *table, - const PsRequestMessage &request, - PsResponseMessage &response, - brpc::Controller *cntl) { +int32_t BrpcPsService::push_global_step(Table *table, + const PsRequestMessage &request, + PsResponseMessage &response, + brpc::Controller *cntl) { CHECK_TABLE_EXIST(table, request, response); auto req_buffer_size = request.data().size(); if (req_buffer_size < 1) { diff --git a/paddle/fluid/distributed/service/brpc_ps_server.h b/paddle/fluid/distributed/service/brpc_ps_server.h index e9eeb5d49c71705c139439afadfe2e18c680930b..c2d0641743a95ed728958024c3b923b5a3253cef 100644 --- a/paddle/fluid/distributed/service/brpc_ps_server.h +++ b/paddle/fluid/distributed/service/brpc_ps_server.h @@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer { std::vector> _pserver_channels; }; -class PsService; +class BrpcPsService; -typedef int32_t (PsService::*serviceHandlerFunc)( +typedef int32_t (BrpcPsService::*serviceHandlerFunc)( Table *table, const PsRequestMessage &request, PsResponseMessage &response, brpc::Controller *cntl); -class PsService : public PsBaseService { +class BrpcPsService : public PsBaseService { public: virtual int32_t initialize() override; virtual void service(::google::protobuf::RpcController *controller, - const ::paddle::PsRequestMessage *request, - ::paddle::PsResponseMessage *response, + const PsRequestMessage *request, + PsResponseMessage *response, ::google::protobuf::Closure *done) override; private: diff --git a/paddle/fluid/distributed/service/brpc_utils.cc b/paddle/fluid/distributed/service/brpc_utils.cc index abd58bf028c2c19e50d18d8b33ff34e2b92e2d3f..82ec10b327197d1d9d86f212f724a2130edef5fc 100644 --- a/paddle/fluid/distributed/service/brpc_utils.cc +++ b/paddle/fluid/distributed/service/brpc_utils.cc @@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var, const platform::DeviceContext& ctx, VarMsg* var_msg, butil::IOBuf* iobuf) { auto* tensor = var->GetMutable(); - var_msg->set_type(::paddle::LOD_TENSOR); + var_msg->set_type(::paddle::distributed::LOD_TENSOR); const framework::LoD lod = tensor->lod(); if (lod.size() > 0) { var_msg->set_lod_level(lod.size()); @@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var, auto* tensor = slr->mutable_value(); auto* rows = slr->mutable_rows(); - var_msg->set_type(::paddle::SELECTED_ROWS); + var_msg->set_type(::paddle::distributed::SELECTED_ROWS); var_msg->set_slr_height(slr->height()); auto* var_data = var_msg->mutable_data(); @@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, ++recv_var_index) { const auto& msg = multi_msg.var_messages(recv_var_index); auto* var = scope->Var(msg.varname()); - if (msg.type() == ::paddle::LOD_TENSOR) { + if (msg.type() == ::paddle::distributed::LOD_TENSOR) { DeserializeLodTensor(var, msg, io_buffer_itr, ctx); - } else if (msg.type() == ::paddle::SELECTED_ROWS) { + } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) { DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); } } @@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, PADDLE_ENFORCE_NE(var, nullptr, platform::errors::InvalidArgument( "Not find variable %s in scope.", msg.varname())); - if (msg.type() == ::paddle::LOD_TENSOR) { + if (msg.type() == ::paddle::distributed::LOD_TENSOR) { DeserializeLodTensor(var, msg, io_buffer_itr, ctx); - } else if (msg.type() == ::paddle::SELECTED_ROWS) { + } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) { DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); } } diff --git a/paddle/fluid/distributed/service/brpc_utils.h b/paddle/fluid/distributed/service/brpc_utils.h index aa340c58a7b8b0ed93d1dd67cd747689be9fe094..6f00adb94a9ddcaa48125a9194291fe0351b1198 100644 --- a/paddle/fluid/distributed/service/brpc_utils.h +++ b/paddle/fluid/distributed/service/brpc_utils.h @@ -44,8 +44,8 @@ class DeviceContext; namespace paddle { namespace distributed { -using MultiVarMsg = ::paddle::MultiVariableMessage; -using VarMsg = ::paddle::VariableMessage; +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; void SerializeToMultiVarMsgAndIOBuf( const std::string& message_name, diff --git a/paddle/fluid/distributed/service/heter_client.cc b/paddle/fluid/distributed/service/heter_client.cc index 311385825b240f2d6c8698999abb83f0c2efc8b1..99def0aef8eeed8e48ea618fbdd516bb690e77c8 100644 --- a/paddle/fluid/distributed/service/heter_client.cc +++ b/paddle/fluid/distributed/service/heter_client.cc @@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync( cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); distributed::MultiVarMsg request, response; auto& request_io_buffer = cntl.request_attachment(); - ::paddle::PsService_Stub stub(xpu_channels_[num].get()); + ::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get()); distributed::SerializeToMultiVarMsgAndIOBuf( message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, &request, &request_io_buffer); @@ -164,7 +164,7 @@ std::future HeterClient::SendCmd( for (const auto& param : params) { closure->request(i)->add_params(param); } - ::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get()); + ::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get()); closure->cntl(i)->set_timeout_ms( FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load rpc_stub.service(closure->cntl(i), closure->request(i), diff --git a/paddle/fluid/distributed/service/heter_client.h b/paddle/fluid/distributed/service/heter_client.h index 0abbe284940443eb3c911ba850fd7f9551814095..a3490281c225576c967af6f70b75c70ff1320ae3 100644 --- a/paddle/fluid/distributed/service/heter_client.h +++ b/paddle/fluid/distributed/service/heter_client.h @@ -35,8 +35,8 @@ limitations under the License. */ namespace paddle { namespace distributed { -using MultiVarMsg = ::paddle::MultiVariableMessage; -using VarMsg = ::paddle::VariableMessage; +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; typedef std::function HeterRpcCallbackFunc; diff --git a/paddle/fluid/distributed/service/heter_server.h b/paddle/fluid/distributed/service/heter_server.h index 04b122d8d2756ffa5cfdfe096bf77aca1843222c..c1c6478787fcb6af451d9b124cb65359e69bdf18 100644 --- a/paddle/fluid/distributed/service/heter_server.h +++ b/paddle/fluid/distributed/service/heter_server.h @@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb); namespace paddle { namespace distributed { -using MultiVarMsg = ::paddle::MultiVariableMessage; -using VarMsg = ::paddle::VariableMessage; +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; class HeterService; typedef int32_t (HeterService::*serviceHandlerFunc)( @@ -51,7 +51,7 @@ typedef std::function HeterRpcCallbackFunc; typedef std::function HeterServiceHandler; -class HeterService : public ::paddle::PsService { +class HeterService : public ::paddle::distributed::PsService { public: HeterService() { _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker; @@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService { virtual ~HeterService() {} virtual void service(::google::protobuf::RpcController* controller, - const ::paddle::PsRequestMessage* request, - ::paddle::PsResponseMessage* response, + const PsRequestMessage* request, + PsResponseMessage* response, ::google::protobuf::Closure* done) { brpc::ClosureGuard done_guard(done); std::string log_label("ReceiveCmd-"); diff --git a/paddle/fluid/distributed/service/ps_client.cc b/paddle/fluid/distributed/service/ps_client.cc index dd5fb9c24b32cebd36f19822d97bde56171dac6d..866200e7740f1fc0d3ee1ade411ddc100b3c51f3 100644 --- a/paddle/fluid/distributed/service/ps_client.cc +++ b/paddle/fluid/distributed/service/ps_client.cc @@ -13,9 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/service/ps_client.h" - #include - #include "brpc/server.h" #include "glog/logging.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h" @@ -23,7 +21,7 @@ namespace paddle { namespace distributed { -REGISTER_CLASS(PSClient, BrpcPsClient); +REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient); int32_t PSClient::configure( const PSParameter &config, @@ -43,7 +41,7 @@ int32_t PSClient::configure( const auto &work_param = _config.worker_param().downpour_worker_param(); for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) { - auto *accessor = CREATE_CLASS( + auto *accessor = CREATE_PSCORE_CLASS( ValueAccessor, work_param.downpour_table_param(i).accessor().accessor_class()); accessor->configure(work_param.downpour_table_param(i).accessor()); @@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) { } const auto &service_param = config.downpour_server_param().service_param(); - PSClient *client = CREATE_CLASS(PSClient, service_param.client_class()); + PSClient *client = + CREATE_PSCORE_CLASS(PSClient, service_param.client_class()); if (client == NULL) { LOG(ERROR) << "client is not registered, server_name:" << service_param.client_class(); diff --git a/paddle/fluid/distributed/service/ps_client.h b/paddle/fluid/distributed/service/ps_client.h index 9d2309faef152c0b1793467eabda44fe2f44d1fa..a23a06c46e0a2e1a5bf1c0eafb073870ee43ddcc 100644 --- a/paddle/fluid/distributed/service/ps_client.h +++ b/paddle/fluid/distributed/service/ps_client.h @@ -28,6 +28,9 @@ namespace paddle { namespace distributed { +using paddle::distributed::PsRequestMessage; +using paddle::distributed::PsResponseMessage; + typedef std::function PSClientCallBack; class PSClientClosure : public google::protobuf::Closure { public: @@ -206,7 +209,7 @@ class PSClient { std::unordered_map _msg_handler_map; //处理client2client消息 }; -REGISTER_REGISTERER(PSClient); +REGISTER_PSCORE_REGISTERER(PSClient); class PSClientFactory { public: diff --git a/paddle/fluid/distributed/service/sendrecv.proto b/paddle/fluid/distributed/service/sendrecv.proto index 0cd849ced51db4b1a882f3b5c61cc76855e8e043..6250f84c98754d31b6f0a4cf6689e4a560549f2c 100644 --- a/paddle/fluid/distributed/service/sendrecv.proto +++ b/paddle/fluid/distributed/service/sendrecv.proto @@ -13,7 +13,7 @@ // limitations under the License. syntax = "proto2"; -package paddle; +package paddle.distributed; option cc_generic_services = true; option cc_enable_arenas = true; diff --git a/paddle/fluid/distributed/service/server.cc b/paddle/fluid/distributed/service/server.cc index fe5ee120dd1ecab0671826be8f0e98bbc6549d42..fc230a0b9c92e646f3dc87231effb7462f2340b6 100644 --- a/paddle/fluid/distributed/service/server.cc +++ b/paddle/fluid/distributed/service/server.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/service/server.h" + #include "glog/logging.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/table/table.h" @@ -20,8 +21,8 @@ namespace paddle { namespace distributed { -REGISTER_CLASS(PSServer, BrpcPsServer); -REGISTER_CLASS(PsBaseService, PsService); +REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer); +REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService); PSServer *PSServerFactory::create(const PSParameter &ps_config) { const auto &config = ps_config.server_param(); @@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) { } const auto &service_param = config.downpour_server_param().service_param(); - PSServer *server = CREATE_CLASS(PSServer, service_param.server_class()); + PSServer *server = + CREATE_PSCORE_CLASS(PSServer, service_param.server_class()); if (server == NULL) { LOG(ERROR) << "server is not registered, server_name:" << service_param.server_class(); @@ -70,7 +72,7 @@ int32_t PSServer::configure( uint32_t global_step_table = UINT32_MAX; for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { - auto *table = CREATE_CLASS( + auto *table = CREATE_PSCORE_CLASS( Table, downpour_param.downpour_table_param(i).table_class()); if (downpour_param.downpour_table_param(i).table_class() == diff --git a/paddle/fluid/distributed/service/server.h b/paddle/fluid/distributed/service/server.h index 532f458e436d2860b53287b274ee849ac9c5538f..78741b8cf80f357933d0b1cbdc9e1887dcae9812 100644 --- a/paddle/fluid/distributed/service/server.h +++ b/paddle/fluid/distributed/service/server.h @@ -46,6 +46,8 @@ namespace paddle { namespace distributed { class Table; +using paddle::distributed::PsRequestMessage; +using paddle::distributed::PsResponseMessage; class PSServer { public: @@ -107,7 +109,7 @@ class PSServer { platform::Place place_ = platform::CPUPlace(); }; -REGISTER_REGISTERER(PSServer); +REGISTER_PSCORE_REGISTERER(PSServer); typedef std::function PServerCallBack; @@ -141,8 +143,8 @@ class PsBaseService : public PsService { return 0; } virtual void service(::google::protobuf::RpcController *controller, - const ::paddle::PsRequestMessage *request, - ::paddle::PsResponseMessage *response, + const PsRequestMessage *request, + PsResponseMessage *response, ::google::protobuf::Closure *done) override = 0; virtual void set_response_code(PsResponseMessage &response, int err_code, @@ -159,7 +161,7 @@ class PsBaseService : public PsService { PSServer *_server; const ServerParameter *_config; }; -REGISTER_REGISTERER(PsBaseService); +REGISTER_PSCORE_REGISTERER(PsBaseService); class PSServerFactory { public: diff --git a/paddle/fluid/distributed/service/service.h b/paddle/fluid/distributed/service/service.h index 539638c803f2cf1db181880268505465da3df836..b4ba691cced5feabc549238c3412203dee11f1c2 100644 --- a/paddle/fluid/distributed/service/service.h +++ b/paddle/fluid/distributed/service/service.h @@ -28,6 +28,10 @@ limitations under the License. */ namespace paddle { namespace distributed { +using paddle::distributed::PsRequestMessage; +using paddle::distributed::PsResponseMessage; +using paddle::distributed::PsService; + class PSCore { public: explicit PSCore() {} diff --git a/paddle/fluid/distributed/table/accessor.h b/paddle/fluid/distributed/table/accessor.h index a07a8e10b16f64539a83ebf55bbe4c43dbb7fef2..7cc92ce98ba6962b5c176a0efb7c5907d8ca7394 100644 --- a/paddle/fluid/distributed/table/accessor.h +++ b/paddle/fluid/distributed/table/accessor.h @@ -165,6 +165,6 @@ class ValueAccessor { std::unordered_map> _data_coverter_map; }; -REGISTER_REGISTERER(ValueAccessor); +REGISTER_PSCORE_REGISTERER(ValueAccessor); } // namespace distributed } // namespace paddle diff --git a/paddle/fluid/distributed/table/table.cc b/paddle/fluid/distributed/table/table.cc index ec08dc58da22e0773663de7c46e926b2205ece0a..31a2399aa35f7e9e6d2f26885889f0171a505965 100644 --- a/paddle/fluid/distributed/table/table.cc +++ b/paddle/fluid/distributed/table/table.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/fluid/distributed/table/table.h" + #include #include #include "glog/logging.h" @@ -27,14 +28,14 @@ namespace paddle { namespace distributed { -REGISTER_CLASS(Table, CommonDenseTable); -REGISTER_CLASS(Table, CommonSparseTable); -REGISTER_CLASS(Table, SparseGeoTable); -REGISTER_CLASS(Table, BarrierTable); -REGISTER_CLASS(Table, TensorTable); -REGISTER_CLASS(Table, DenseTensorTable); -REGISTER_CLASS(Table, GlobalStepTable); -REGISTER_CLASS(ValueAccessor, CommMergeAccessor); +REGISTER_PSCORE_CLASS(Table, CommonDenseTable); +REGISTER_PSCORE_CLASS(Table, CommonSparseTable); +REGISTER_PSCORE_CLASS(Table, SparseGeoTable); +REGISTER_PSCORE_CLASS(Table, BarrierTable); +REGISTER_PSCORE_CLASS(Table, TensorTable); +REGISTER_PSCORE_CLASS(Table, DenseTensorTable); +REGISTER_PSCORE_CLASS(Table, GlobalStepTable); +REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor); int32_t TableManager::initialize() { static bool initialized = false; @@ -61,9 +62,9 @@ int32_t Table::initialize_accessor() { << _config.table_id(); return -1; } - auto *accessor = - CREATE_CLASS(ValueAccessor, - _config.accessor().accessor_class()) if (accessor == NULL) { + auto *accessor = CREATE_PSCORE_CLASS( + ValueAccessor, + _config.accessor().accessor_class()) if (accessor == NULL) { LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id() << ", accessor_name:" << _config.accessor().accessor_class(); return -1; diff --git a/paddle/fluid/distributed/table/table.h b/paddle/fluid/distributed/table/table.h index 376d4a525b20decb43f2b2aa48e456716a603470..1bfedb53ab83d331d32b3ce828b0c1493c0ccc33 100644 --- a/paddle/fluid/distributed/table/table.h +++ b/paddle/fluid/distributed/table/table.h @@ -127,7 +127,7 @@ class Table { float *_global_lr = nullptr; std::shared_ptr _value_accesor; }; -REGISTER_REGISTERER(Table); +REGISTER_PSCORE_REGISTERER(Table); class TableManager { public: diff --git a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc index a7af4c82897f1c97e4389f47936086fa8e145703..b793927e77f65b5ade406cd6a31f19040e448c55 100644 --- a/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_dense_sgd_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include + #include // NOLINT #include #include // NOLINT @@ -94,7 +95,7 @@ void GetDownpourDenseTableProto( server_proto->mutable_downpour_server_param(); ::paddle::distributed::ServerServiceParameter* server_service_proto = downpour_server_proto->mutable_service_param(); - server_service_proto->set_service_class("PsService"); + server_service_proto->set_service_class("BrpcPsService"); server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_start_server_port(0); @@ -124,7 +125,7 @@ void GetDownpourDenseTableProto( server_proto->mutable_downpour_server_param(); ::paddle::distributed::ServerServiceParameter* server_service_proto = downpour_server_proto->mutable_service_param(); - server_service_proto->set_service_class("PsService"); + server_service_proto->set_service_class("BrpcPsService"); server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_start_server_port(0); @@ -244,7 +245,8 @@ void RunBrpcPushDense() { int ret = 0; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; for (size_t i = 0; i < 1; ++i) { - if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) { + if (closure->check_response( + i, paddle::distributed::PS_PUSH_DENSE_TABLE) != 0) { ret = -1; break; } diff --git a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc index 8cee608d5f76eb3daa88e93dbb5187171053d3f0..ddeb7b5023264f98932a34f0c3a1ba9004fe979f 100644 --- a/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc +++ b/paddle/fluid/distributed/test/brpc_service_sparse_sgd_test.cc @@ -18,6 +18,7 @@ limitations under the License. */ #include // NOLINT #include "google/protobuf/text_format.h" + #include "gtest/gtest.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/scope.h" @@ -94,7 +95,7 @@ void GetDownpourSparseTableProto( server_proto->mutable_downpour_server_param(); ::paddle::distributed::ServerServiceParameter* server_service_proto = downpour_server_proto->mutable_service_param(); - server_service_proto->set_service_class("PsService"); + server_service_proto->set_service_class("BrpcPsService"); server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_start_server_port(0); @@ -124,7 +125,7 @@ void GetDownpourSparseTableProto( server_proto->mutable_downpour_server_param(); ::paddle::distributed::ServerServiceParameter* server_service_proto = downpour_server_proto->mutable_service_param(); - server_service_proto->set_service_class("PsService"); + server_service_proto->set_service_class("BrpcPsService"); server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_start_server_port(0); @@ -225,7 +226,8 @@ void RunBrpcPushSparse() { int ret = 0; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; for (size_t i = 0; i < 1; ++i) { - if (closure->check_response(i, paddle::PS_PUSH_SPARSE_PARAM) != 0) { + if (closure->check_response( + i, paddle::distributed::PS_PUSH_SPARSE_PARAM) != 0) { ret = -1; break; } @@ -252,7 +254,8 @@ void RunBrpcPushSparse() { int ret = 0; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; for (size_t i = 0; i < 1; ++i) { - if (closure->check_response(i, paddle::PS_PUSH_SPARSE_TABLE) != 0) { + if (closure->check_response( + i, paddle::distributed::PS_PUSH_SPARSE_TABLE) != 0) { ret = -1; break; } diff --git a/paddle/fluid/distributed/test/brpc_utils_test.cc b/paddle/fluid/distributed/test/brpc_utils_test.cc index ce33cbe6ea39713589eb8f201d95bb2d99e5ff0c..531d995512f7c4965b098a7315ef81cb79070ca8 100644 --- a/paddle/fluid/distributed/test/brpc_utils_test.cc +++ b/paddle/fluid/distributed/test/brpc_utils_test.cc @@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) { auto& ctx = *pool.Get(place); CreateVarsOnScope(&scope, &place, ctx); - ::paddle::MultiVariableMessage multi_msg; + ::paddle::distributed::MultiVariableMessage multi_msg; std::string message_name("se_de_test"); std::vector send_var_name = {"x1", "x2", "x3"}; std::vector recv_var_name = {}; @@ -138,4 +138,4 @@ TEST(MultiVarMsgCPU, Run) { // platform::CUDAPlace place; // RunMultiVarMsg(place); // } -// #endif \ No newline at end of file +// #endif diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 00b17f6a109af018bba5359fcb1c6c5f1d0dd1c0..f96b9475f569039763de8e6efe6a80764f502c94 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE) pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry device_context scope framework_proto trainer_desc_proto glog fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer - lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} + lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper data_feed_proto timer monitor heter_service_proto pslib_brpc) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) - else() + elseif(WITH_PSCORE) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc heterxpu_trainer.cc @@ -228,6 +228,16 @@ if(WITH_DISTRIBUTE) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(multi_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(hogwild_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + else() + cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc + dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc + heterxpu_trainer.cc + data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc + heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc + pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry + device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog + lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method + graph_to_program_pass variable_helper timer monitor) endif() elseif(WITH_PSLIB) cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc @@ -239,7 +249,6 @@ elseif(WITH_PSLIB) device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method graph_to_program_pass variable_helper timer monitor pslib_brpc ) - else() cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index b38abde25401df932a73cfa21fe52fcd586c5755..0c9e30fd195193323bb13470d2fbf9075779c939 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) -if(WITH_DISTRIBUTE) +if(WITH_PSCORE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 679ace135b699fc5fab8aabb798b2edbe8343435..b8fac755709e768e8d413a8f91cbc0ae45d27195 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/framework/variable_helper.h" -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/service/communicator.h" #endif @@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run( "results to be fetched!")); // init once if (run_futures_.size() == 0 && places_.size() > 1) { -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE if (strategy_.thread_barrier_) { paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset( places_.size()); diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index 00201bd442e3b968113c6c7c351f257300fcbbdb..265e346a9d8dfb0925783b812174410bb11ae86d 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/profiler.h" -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/service/communicator.h" #endif @@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { void ThreadedSSAGraphExecutor::ExecutionFinal( std::vector *fetch_ops) { -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE if (strategy_.thread_barrier_) { paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); } diff --git a/paddle/fluid/framework/hogwild_worker.cc b/paddle/fluid/framework/hogwild_worker.cc index a7f09723f152d48f59a8db738569e9b786939a8b..7aaaba510469dac7affedad78a768d2bfb68640f 100644 --- a/paddle/fluid/framework/hogwild_worker.cc +++ b/paddle/fluid/framework/hogwild_worker.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/lodtensor_printer.h" -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/service/communicator.h" #endif @@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() { writer_.Flush(); } -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE if (thread_barrier_) { paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); } @@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() { PrintFetchVars(); thread_scope_->DropKids(); } -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE if (thread_barrier_) { paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); } diff --git a/paddle/fluid/framework/multi_trainer.cc b/paddle/fluid/framework/multi_trainer.cc index 216cf06f32fdd47b3e5bb89dde17fb5d300ffd41..2c72fa45656d78b01d068b466ffbce345a265b6d 100644 --- a/paddle/fluid/framework/multi_trainer.cc +++ b/paddle/fluid/framework/multi_trainer.cc @@ -18,7 +18,7 @@ limitations under the License. */ #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE #include "paddle/fluid/distributed/service/communicator.h" #endif @@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, VLOG(3) << "worker thread num: " << thread_num_; workers_.resize(thread_num_); -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE if (trainer_desc.thread_barrier()) { paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset( thread_num_); diff --git a/paddle/fluid/inference/CMakeLists.txt b/paddle/fluid/inference/CMakeLists.txt index 5207b89e2987cc284c4fd64f03cc817751e5b55d..1896be4f9216b59a2c743616c776fc5f62482dd0 100644 --- a/paddle/fluid/inference/CMakeLists.txt +++ b/paddle/fluid/inference/CMakeLists.txt @@ -77,12 +77,12 @@ set(SHARED_INFERENCE_SRCS ${mkldnn_quantizer_src_file}) # Create shared inference library defaultly -if(NOT WITH_DISTRIBUTE) +if(NOT WITH_PSCORE) cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} - DEPS ${fluid_modules} analysis_predictor) + DEPS ${fluid_modules} analysis_predictor) else() cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} - DEPS ${fluid_modules} analysis_predictor fleet ps_service) + DEPS ${fluid_modules} analysis_predictor fleet ps_service) endif() get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index c8f07d8b4647823068ab0349e03ddd6e160ec99a..28741ce94718fd6b9752cd2561db52fef4204602 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -22,10 +22,13 @@ add_subdirectory(jit) if(WITH_DISTRIBUTE) - add_subdirectory(pscore) add_subdirectory(collective) endif() +if (WITH_PSCORE) + add_subdirectory(pscore) +endif() + add_subdirectory(amp) add_subdirectory(reader) diff --git a/paddle/fluid/operators/pscore/CMakeLists.txt b/paddle/fluid/operators/pscore/CMakeLists.txt index 7688f0e2a964021672ce7910c1b88b5282a8b15f..3e388b8d5ea104b99cc2c55cb130c01afa77fde4 100644 --- a/paddle/fluid/operators/pscore/CMakeLists.txt +++ b/paddle/fluid/operators/pscore/CMakeLists.txt @@ -1,3 +1,7 @@ +if (WITH_PSLIB) + return() +endif() + include(operators) set(DISTRIBUTE_DEPS "") diff --git a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h index 33a287ad90ed4161c2737174234c2e88dfb1cb0e..4985d033e2da6ec19bbb1a2972f0c18fbc69cade 100644 --- a/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h +++ b/paddle/fluid/operators/pscore/heter_listen_and_serv_op.h @@ -46,8 +46,8 @@ class DeviceContext; namespace paddle { namespace operators { -using MultiVarMsg = ::paddle::MultiVariableMessage; -using VarMsg = ::paddle::VariableMessage; +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; template class DoubleFindMap : public std::unordered_map { diff --git a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc index 2393a61dc0f1997d95939f5a4be4f1bf8d69d5a2..767856ccde9c5152de6f74d3f0e4333eca57017a 100644 --- a/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc +++ b/paddle/fluid/operators/pscore/heter_listen_and_server_test.cc @@ -36,8 +36,8 @@ namespace framework = paddle::framework; namespace platform = paddle::platform; namespace distributed = paddle::distributed; -using MultiVarMsg = ::paddle::MultiVariableMessage; -using VarMsg = ::paddle::VariableMessage; +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; DECLARE_double(eager_delete_tensor_gb); USE_OP(scale); diff --git a/paddle/fluid/operators/pscore/heter_server_test.cc b/paddle/fluid/operators/pscore/heter_server_test.cc index d95988719d5f8cd3fd84563fd2e66cdd50518976..02832ca72df400e0961997887299aa67cf799f29 100644 --- a/paddle/fluid/operators/pscore/heter_server_test.cc +++ b/paddle/fluid/operators/pscore/heter_server_test.cc @@ -32,8 +32,8 @@ namespace framework = paddle::framework; namespace platform = paddle::platform; namespace distributed = paddle::distributed; -using MultiVarMsg = ::paddle::MultiVariableMessage; -using VarMsg = ::paddle::VariableMessage; +using MultiVarMsg = ::paddle::distributed::MultiVariableMessage; +using VarMsg = ::paddle::distributed::VariableMessage; USE_OP(scale); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 93c42e692c4f5d3dc02537799dc74c51e24ace0c..1e4bf43f62ed4fdc096ccf53caa0e229a0bec110 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -49,7 +49,7 @@ if (WITH_CRYPTO) set(PYBIND_SRCS ${PYBIND_SRCS} crypto.cc) endif (WITH_CRYPTO) -if (WITH_DISTRIBUTE) +if (WITH_PSCORE) set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result") set_source_files_properties(fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) list(APPEND PYBIND_DEPS fleet communicator) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 5f4c5fd2c30a453eca30d431d4a3fcdcf70da5b4..b66dd17bbcd2bc91b5131df366033558106dd60f 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -107,7 +107,7 @@ limitations under the License. */ #include "paddle/fluid/pybind/crypto.h" #endif -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE #include "paddle/fluid/pybind/fleet_py.h" #endif @@ -2841,7 +2841,7 @@ All parameter, weight, gradient are variables in Paddle. BindCrypto(&m); #endif -#ifdef PADDLE_WITH_DISTRIBUTE +#if defined PADDLE_WITH_PSCORE BindDistFleetWrapper(&m); BindPSHost(&m); BindCommunicatorContext(&m); diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh index 0f5d2d3bc2bbb5351ae9d9012748918c4b17ab1c..fc4de4565b8e49bce96af6912a98a1dbd2ecf179 100755 --- a/paddle/scripts/paddle_build.sh +++ b/paddle/scripts/paddle_build.sh @@ -249,6 +249,7 @@ function cmake_base() { -DPY_VERSION=${PY_VERSION:-2.7} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} -DWITH_GRPC=${grpc_flag} + -DWITH_PSCORE=${distibuted_flag} -DWITH_GLOO=${gloo_flag} -DWITH_LITE=${WITH_LITE:-OFF} -DWITH_XPU=${WITH_XPU:-OFF} @@ -284,6 +285,7 @@ EOF -DPY_VERSION=${PY_VERSION:-2.7} \ -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} \ -DWITH_GRPC=${grpc_flag} \ + -DWITH_PSCORE=${distibuted_flag} \ -DWITH_GLOO=${gloo_flag} \ -DLITE_GIT_TAG=develop \ -DWITH_XPU=${WITH_XPU:-OFF} \ diff --git a/paddle/testing/paddle_gtest_main.cc b/paddle/testing/paddle_gtest_main.cc index eb038fb98d60c9830043371db7de06ff8b349063..fab6eea49bff0f28954c9c8ea1b728d07982786c 100644 --- a/paddle/testing/paddle_gtest_main.cc +++ b/paddle/testing/paddle_gtest_main.cc @@ -59,7 +59,8 @@ int main(int argc, char** argv) { std::vector envs; std::vector undefok; -#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) +#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) && \ + !defined(PADDLE_WITH_PSLIB) std::string str_max_body_size; if (google::GetCommandLineOption("max_body_size", &str_max_body_size)) { setenv("FLAGS_max_body_size", "2147483647", 1); diff --git a/python/paddle/distributed/fleet/runtime/the_one_ps.py b/python/paddle/distributed/fleet/runtime/the_one_ps.py index 74a961eff0297fc8a4c320a9795158871e85d7c8..37d79abbab08ea3866141c2636d47dad01ed4830 100644 --- a/python/paddle/distributed/fleet/runtime/the_one_ps.py +++ b/python/paddle/distributed/fleet/runtime/the_one_ps.py @@ -268,7 +268,7 @@ class Service: def __init__(self): self.server_class = "BrpcPsServer" self.client_class = "BrpcPsClient" - self.service_class = "PsService" + self.service_class = "BrpcPsService" self.start_server_port = 0 self.server_thread_num = 12