未验证 提交 a97ca56a 编写于 作者: T tangwei12 提交者: GitHub

split ps with distributed (#30337)

Change-Id: I3c788e7576688e63181e7f01562529b85a09cc59
上级 5eab1a38
...@@ -136,6 +136,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF) ...@@ -136,6 +136,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
option(WITH_XBYAK "Compile with xbyak support" ON) option(WITH_XBYAK "Compile with xbyak support" ON)
option(WITH_CONTRIB "Compile the third-party contributation" OFF) option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE}) option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_PSCORE "Compile with parameter server support" ${WITH_DISTRIBUTE})
option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF) option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION}) option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE}) option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE})
......
...@@ -156,6 +156,11 @@ if(WITH_DISTRIBUTE) ...@@ -156,6 +156,11 @@ if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE) add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif() endif()
if(WITH_PSCORE)
add_definitions(-DPADDLE_WITH_PSCORE)
endif()
if(WITH_GRPC) if(WITH_GRPC)
add_definitions(-DPADDLE_WITH_GRPC) add_definitions(-DPADDLE_WITH_GRPC)
endif(WITH_GRPC) endif(WITH_GRPC)
......
...@@ -280,7 +280,7 @@ if(WITH_BOX_PS) ...@@ -280,7 +280,7 @@ if(WITH_BOX_PS)
list(APPEND third_party_deps extern_box_ps) list(APPEND third_party_deps extern_box_ps)
endif(WITH_BOX_PS) endif(WITH_BOX_PS)
if (WITH_DISTRIBUTE) if (WITH_PSCORE)
include(external/snappy) include(external/snappy)
list(APPEND third_party_deps extern_snappy) list(APPEND third_party_deps extern_snappy)
......
if (WITH_PSLIB) if(NOT WITH_PSCORE)
return()
endif()
if(NOT WITH_DISTRIBUTE)
return() return()
endif() endif()
......
...@@ -69,24 +69,24 @@ class ObjectFactory { ...@@ -69,24 +69,24 @@ class ObjectFactory {
}; };
typedef std::map<std::string, ObjectFactory *> FactoryMap; typedef std::map<std::string, ObjectFactory *> FactoryMap;
typedef std::map<std::string, FactoryMap> BaseClassMap; typedef std::map<std::string, FactoryMap> PsCoreClassMap;
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
inline BaseClassMap &global_factory_map() { inline PsCoreClassMap &global_factory_map() {
static BaseClassMap *base_class = new BaseClassMap(); static PsCoreClassMap *base_class = new PsCoreClassMap();
return *base_class; return *base_class;
} }
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } inline PsCoreClassMap &global_factory_map_cpp() { return global_factory_map(); }
// typedef pa::Any Any; // typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap; // typedef ::FactoryMap FactoryMap;
#define REGISTER_REGISTERER(base_class) \ #define REGISTER_PSCORE_REGISTERER(base_class) \
class base_class##Registerer { \ class base_class##Registerer { \
public: \ public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \ static base_class *CreateInstanceByName(const ::std::string &name) { \
...@@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } ...@@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \ } \
}; };
#define REGISTER_CLASS(clazz, name) \ #define REGISTER_PSCORE_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \ class ObjectFactory##name : public ObjectFactory { \
public: \ public: \
Any NewInstance() { return Any(new name()); } \ Any NewInstance() { return Any(new name()); } \
...@@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); } ...@@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \ } \
void register_factory_##name() __attribute__((constructor)); void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \ #define CREATE_PSCORE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name); base_class##Registerer::CreateInstanceByName(name);
} // namespace distributed } // namespace distributed
......
...@@ -86,7 +86,7 @@ message SparseTableParameter { ...@@ -86,7 +86,7 @@ message SparseTableParameter {
message ServerServiceParameter { message ServerServiceParameter {
optional string server_class = 1 [ default = "BrpcPsServer" ]; optional string server_class = 1 [ default = "BrpcPsServer" ];
optional string client_class = 2 [ default = "BrpcPsClient" ]; optional string client_class = 2 [ default = "BrpcPsClient" ];
optional string service_class = 3 [ default = "PsService" ]; optional string service_class = 3 [ default = "BrpcPsService" ];
optional uint32 start_server_port = 4 optional uint32 start_server_port = 4
[ default = 0 ]; // will find a avaliable port from it [ default = 0 ]; // will find a avaliable port from it
optional uint32 server_thread_num = 5 [ default = 12 ]; optional uint32 server_thread_num = 5 [ default = 12 ];
......
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "Eigen/Dense" #include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h" #include "paddle/fluid/framework/archive.h"
...@@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num, ...@@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
void DownpourPsClientService::service( void DownpourPsClientService::service(
::google::protobuf::RpcController *controller, ::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request, PsResponseMessage *response,
::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) { ::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
int ret = _client->handle_client2client_msg( int ret = _client->handle_client2client_msg(
request->cmd_id(), request->client_id(), request->data()); request->cmd_id(), request->client_id(), request->data());
......
...@@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService { ...@@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
return 0; return 0;
} }
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request,
::paddle::PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
protected: protected:
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT #include <thread> // NOLINT
#include "Eigen/Dense" #include "Eigen/Dense"
#include "butil/endpoint.h" #include "butil/endpoint.h"
...@@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() { ...@@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() {
LOG(ERROR) << "miss service_class in ServerServiceParameter"; LOG(ERROR) << "miss service_class in ServerServiceParameter";
return -1; return -1;
} }
auto *service = CREATE_CLASS(PsBaseService, service_config.service_class()); auto *service =
CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
if (service == NULL) { if (service == NULL) {
LOG(ERROR) << "service is unregistered, service_name:" LOG(ERROR) << "service is unregistered, service_name:"
<< service_config.service_class(); << service_config.service_class();
...@@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { ...@@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
int32_t BrpcPsServer::port() { return _server.listen_address().port; } int32_t BrpcPsServer::port() { return _server.listen_address().port; }
int32_t PsService::initialize() { int32_t BrpcPsService::initialize() {
_is_initialize_shard_info = false; _is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &PsService::stop_server; _service_handler_map[PS_STOP_SERVER] = &BrpcPsService::stop_server;
_service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense; _service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::pull_dense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense; _service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::push_dense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse; _service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::pull_sparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse; _service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::push_sparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table; _service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::save_one_table;
_service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table; _service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::save_all_table;
_service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table; _service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::shrink_table;
_service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table; _service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table; _service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::load_all_table;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table; _service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::clear_one_table;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table; _service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::clear_all_table;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param; _service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::push_dense_param;
_service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat; _service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::print_table_stat;
_service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param; _service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::pull_geo_param;
_service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param; _service_handler_map[PS_PUSH_SPARSE_PARAM] =
_service_handler_map[PS_BARRIER] = &PsService::barrier; &BrpcPsService::push_sparse_param;
_service_handler_map[PS_START_PROFILER] = &PsService::start_profiler; _service_handler_map[PS_BARRIER] = &BrpcPsService::barrier;
_service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler; _service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &PsService::push_global_step; _service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler;
// shard初始化,server启动后才可从env获取到server_list的shard信息 // shard初始化,server启动后才可从env获取到server_list的shard信息
initialize_shard_info(); initialize_shard_info();
...@@ -116,7 +118,7 @@ int32_t PsService::initialize() { ...@@ -116,7 +118,7 @@ int32_t PsService::initialize() {
return -1; \ return -1; \
} }
int32_t PsService::initialize_shard_info() { int32_t BrpcPsService::initialize_shard_info() {
if (!_is_initialize_shard_info) { if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex); std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) { if (_is_initialize_shard_info) {
...@@ -132,10 +134,10 @@ int32_t PsService::initialize_shard_info() { ...@@ -132,10 +134,10 @@ int32_t PsService::initialize_shard_info() {
return 0; return 0;
} }
void PsService::service(google::protobuf::RpcController *cntl_base, void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
const PsRequestMessage *request, const PsRequestMessage *request,
PsResponseMessage *response, PsResponseMessage *response,
google::protobuf::Closure *done) { google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-"); std::string log_label("ReceiveCmd-");
if (!request->has_table_id()) { if (!request->has_table_id()) {
...@@ -163,9 +165,9 @@ void PsService::service(google::protobuf::RpcController *cntl_base, ...@@ -163,9 +165,9 @@ void PsService::service(google::protobuf::RpcController *cntl_base,
} }
} }
int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_dense"); platform::RecordEvent record_event("PsService->pull_dense");
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) { if (request.params_size() < 1) {
...@@ -191,10 +193,10 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request, ...@@ -191,10 +193,10 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::push_dense_param(Table *table, int32_t BrpcPsService::push_dense_param(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense_param"); platform::RecordEvent record_event("PsService->push_dense_param");
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_buffer; thread_local std::string push_buffer;
...@@ -218,9 +220,9 @@ int32_t PsService::push_dense_param(Table *table, ...@@ -218,9 +220,9 @@ int32_t PsService::push_dense_param(Table *table,
return 0; return 0;
} }
int32_t PsService::push_dense(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense"); platform::RecordEvent record_event("PsService->push_dense");
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
auto req_buffer_size = request.data().size(); auto req_buffer_size = request.data().size();
...@@ -244,9 +246,9 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request, ...@@ -244,9 +246,9 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::barrier(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) { if (request.params_size() < 1) {
...@@ -262,10 +264,10 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request, ...@@ -262,10 +264,10 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::push_sparse_param(Table *table, int32_t BrpcPsService::push_sparse_param(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse_param"); platform::RecordEvent record_event("PsService->push_sparse_param");
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data(); auto &push_data = request.data();
...@@ -294,9 +296,10 @@ int32_t PsService::push_sparse_param(Table *table, ...@@ -294,9 +296,10 @@ int32_t PsService::push_sparse_param(Table *table,
return 0; return 0;
} }
int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::pull_geo_param(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_geo_param"); platform::RecordEvent record_event("PsService->pull_geo_param");
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer; thread_local std::string push_sparse_request_buffer;
...@@ -316,9 +319,10 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request, ...@@ -316,9 +319,10 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::pull_sparse(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_sparse"); platform::RecordEvent record_event("PsService->pull_sparse");
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer; thread_local std::string push_sparse_request_buffer;
...@@ -353,9 +357,10 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request, ...@@ -353,9 +357,10 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::push_sparse(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse"); platform::RecordEvent record_event("PsService->push_sparse");
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data(); auto &push_data = request.data();
...@@ -384,10 +389,10 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request, ...@@ -384,10 +389,10 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::print_table_stat(Table *table, int32_t BrpcPsService::print_table_stat(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat(); std::pair<int64_t, int64_t> ret = table->print_table_stat();
paddle::framework::BinaryArchive ar; paddle::framework::BinaryArchive ar;
...@@ -398,9 +403,10 @@ int32_t PsService::print_table_stat(Table *table, ...@@ -398,9 +403,10 @@ int32_t PsService::print_table_stat(Table *table,
return 0; return 0;
} }
int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::load_one_table(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
...@@ -415,9 +421,10 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request, ...@@ -415,9 +421,10 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::load_all_table(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->table());
for (auto &itr : table_map) { for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) { if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
...@@ -428,9 +435,10 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request, ...@@ -428,9 +435,10 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::save_one_table(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) { if (request.params_size() < 2) {
set_response_code( set_response_code(
...@@ -449,9 +457,10 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request, ...@@ -449,9 +457,10 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
return feasign_size; return feasign_size;
} }
int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::save_all_table(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->table());
int32_t all_feasign_size = 0; int32_t all_feasign_size = 0;
int32_t feasign_size = 0; int32_t feasign_size = 0;
...@@ -466,9 +475,10 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request, ...@@ -466,9 +475,10 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::shrink_table(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
table->flush(); table->flush();
if (table->shrink() != 0) { if (table->shrink() != 0) {
...@@ -477,20 +487,20 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request, ...@@ -477,20 +487,20 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::clear_one_table(Table *table, int32_t BrpcPsService::clear_one_table(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response) CHECK_TABLE_EXIST(table, request, response)
table->flush(); table->flush();
table->clear(); table->clear();
return 0; return 0;
} }
int32_t PsService::clear_all_table(Table *table, int32_t BrpcPsService::clear_all_table(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
auto &table_map = *(_server->table()); auto &table_map = *(_server->table());
for (auto &itr : table_map) { for (auto &itr : table_map) {
if (clear_one_table(itr.second.get(), request, response, cntl) != 0) { if (clear_one_table(itr.second.get(), request, response, cntl) != 0) {
...@@ -500,9 +510,10 @@ int32_t PsService::clear_all_table(Table *table, ...@@ -500,9 +510,10 @@ int32_t PsService::clear_all_table(Table *table,
return 0; return 0;
} }
int32_t PsService::stop_server(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::stop_server(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
auto *p_server = _server; auto *p_server = _server;
std::thread t_stop([p_server]() { std::thread t_stop([p_server]() {
p_server->stop(); p_server->stop();
...@@ -512,25 +523,27 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request, ...@@ -512,25 +523,27 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
return 0; return 0;
} }
int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::stop_profiler(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
platform::DisableProfiler(platform::EventSortingKey::kDefault, platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank)); string::Sprintf("server_%s_profile", _rank));
return 0; return 0;
} }
int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request, int32_t BrpcPsService::start_profiler(Table *table,
PsResponseMessage &response, const PsRequestMessage &request,
brpc::Controller *cntl) { PsResponseMessage &response,
brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU); platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0; return 0;
} }
int32_t PsService::push_global_step(Table *table, int32_t BrpcPsService::push_global_step(Table *table,
const PsRequestMessage &request, const PsRequestMessage &request,
PsResponseMessage &response, PsResponseMessage &response,
brpc::Controller *cntl) { brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response); CHECK_TABLE_EXIST(table, request, response);
auto req_buffer_size = request.data().size(); auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) { if (req_buffer_size < 1) {
......
...@@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer { ...@@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels; std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
}; };
class PsService; class BrpcPsService;
typedef int32_t (PsService::*serviceHandlerFunc)( typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response, Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl); brpc::Controller *cntl);
class PsService : public PsBaseService { class BrpcPsService : public PsBaseService {
public: public:
virtual int32_t initialize() override; virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request,
::paddle::PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override; ::google::protobuf::Closure *done) override;
private: private:
......
...@@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var, ...@@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg, const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) { butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>(); auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::LOD_TENSOR); var_msg->set_type(::paddle::distributed::LOD_TENSOR);
const framework::LoD lod = tensor->lod(); const framework::LoD lod = tensor->lod();
if (lod.size() > 0) { if (lod.size() > 0) {
var_msg->set_lod_level(lod.size()); var_msg->set_lod_level(lod.size());
...@@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var, ...@@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
auto* tensor = slr->mutable_value(); auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows(); auto* rows = slr->mutable_rows();
var_msg->set_type(::paddle::SELECTED_ROWS); var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
var_msg->set_slr_height(slr->height()); var_msg->set_slr_height(slr->height());
auto* var_data = var_msg->mutable_data(); auto* var_data = var_msg->mutable_data();
...@@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, ...@@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++recv_var_index) { ++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index); const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->Var(msg.varname()); auto* var = scope->Var(msg.varname());
if (msg.type() == ::paddle::LOD_TENSOR) { if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx); DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) { } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
} }
} }
...@@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg, ...@@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
PADDLE_ENFORCE_NE(var, nullptr, PADDLE_ENFORCE_NE(var, nullptr,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname())); "Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::LOD_TENSOR) { if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx); DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) { } else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx); DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
} }
} }
......
...@@ -44,8 +44,8 @@ class DeviceContext; ...@@ -44,8 +44,8 @@ class DeviceContext;
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
void SerializeToMultiVarMsgAndIOBuf( void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name, const std::string& message_name,
......
...@@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync( ...@@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms); cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response; distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment(); auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get()); ::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
distributed::SerializeToMultiVarMsgAndIOBuf( distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope, message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer); &request, &request_io_buffer);
...@@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd( ...@@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
for (const auto& param : params) { for (const auto& param : params) {
closure->request(i)->add_params(param); closure->request(i)->add_params(param);
} }
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get()); ::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms( closure->cntl(i)->set_timeout_ms(
FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i), rpc_stub.service(closure->cntl(i), closure->request(i),
......
...@@ -35,8 +35,8 @@ limitations under the License. */ ...@@ -35,8 +35,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
typedef std::function<void(void*)> HeterRpcCallbackFunc; typedef std::function<void(void*)> HeterRpcCallbackFunc;
......
...@@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb); ...@@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
class HeterService; class HeterService;
typedef int32_t (HeterService::*serviceHandlerFunc)( typedef int32_t (HeterService::*serviceHandlerFunc)(
...@@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc; ...@@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)> typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
HeterServiceHandler; HeterServiceHandler;
class HeterService : public ::paddle::PsService { class HeterService : public ::paddle::distributed::PsService {
public: public:
HeterService() { HeterService() {
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker; _service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
...@@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService { ...@@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
virtual ~HeterService() {} virtual ~HeterService() {}
virtual void service(::google::protobuf::RpcController* controller, virtual void service(::google::protobuf::RpcController* controller,
const ::paddle::PsRequestMessage* request, const PsRequestMessage* request,
::paddle::PsResponseMessage* response, PsResponseMessage* response,
::google::protobuf::Closure* done) { ::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-"); std::string log_label("ReceiveCmd-");
......
...@@ -13,9 +13,7 @@ ...@@ -13,9 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h" #include "paddle/fluid/distributed/service/ps_client.h"
#include <map> #include <map>
#include "brpc/server.h" #include "brpc/server.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h" #include "paddle/fluid/distributed/service/brpc_ps_client.h"
...@@ -23,7 +21,7 @@ ...@@ -23,7 +21,7 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
REGISTER_CLASS(PSClient, BrpcPsClient); REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
int32_t PSClient::configure( int32_t PSClient::configure(
const PSParameter &config, const PSParameter &config,
...@@ -43,7 +41,7 @@ int32_t PSClient::configure( ...@@ -43,7 +41,7 @@ int32_t PSClient::configure(
const auto &work_param = _config.worker_param().downpour_worker_param(); const auto &work_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) { for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_CLASS( auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor, ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class()); work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor()); accessor->configure(work_param.downpour_table_param(i).accessor());
...@@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) { ...@@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
} }
const auto &service_param = config.downpour_server_param().service_param(); const auto &service_param = config.downpour_server_param().service_param();
PSClient *client = CREATE_CLASS(PSClient, service_param.client_class()); PSClient *client =
CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
if (client == NULL) { if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:" LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class(); << service_param.client_class();
......
...@@ -28,6 +28,9 @@ ...@@ -28,6 +28,9 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
typedef std::function<void(void *)> PSClientCallBack; typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure { class PSClientClosure : public google::protobuf::Closure {
public: public:
...@@ -206,7 +209,7 @@ class PSClient { ...@@ -206,7 +209,7 @@ class PSClient {
std::unordered_map<int32_t, MsgHandlerFunc> std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; //处理client2client消息 _msg_handler_map; //处理client2client消息
}; };
REGISTER_REGISTERER(PSClient); REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory { class PSClientFactory {
public: public:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
syntax = "proto2"; syntax = "proto2";
package paddle; package paddle.distributed;
option cc_generic_services = true; option cc_generic_services = true;
option cc_enable_arenas = true; option cc_enable_arenas = true;
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/service/server.h" #include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/distributed/table/table.h"
...@@ -20,8 +21,8 @@ ...@@ -20,8 +21,8 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
REGISTER_CLASS(PSServer, BrpcPsServer); REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_CLASS(PsBaseService, PsService); REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) { PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param(); const auto &config = ps_config.server_param();
...@@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) { ...@@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
} }
const auto &service_param = config.downpour_server_param().service_param(); const auto &service_param = config.downpour_server_param().service_param();
PSServer *server = CREATE_CLASS(PSServer, service_param.server_class()); PSServer *server =
CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
if (server == NULL) { if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:" LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class(); << service_param.server_class();
...@@ -70,7 +72,7 @@ int32_t PSServer::configure( ...@@ -70,7 +72,7 @@ int32_t PSServer::configure(
uint32_t global_step_table = UINT32_MAX; uint32_t global_step_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) { for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_CLASS( auto *table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class()); Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() == if (downpour_param.downpour_table_param(i).table_class() ==
......
...@@ -46,6 +46,8 @@ namespace paddle { ...@@ -46,6 +46,8 @@ namespace paddle {
namespace distributed { namespace distributed {
class Table; class Table;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
class PSServer { class PSServer {
public: public:
...@@ -107,7 +109,7 @@ class PSServer { ...@@ -107,7 +109,7 @@ class PSServer {
platform::Place place_ = platform::CPUPlace(); platform::Place place_ = platform::CPUPlace();
}; };
REGISTER_REGISTERER(PSServer); REGISTER_PSCORE_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack; typedef std::function<void(void *)> PServerCallBack;
...@@ -141,8 +143,8 @@ class PsBaseService : public PsService { ...@@ -141,8 +143,8 @@ class PsBaseService : public PsService {
return 0; return 0;
} }
virtual void service(::google::protobuf::RpcController *controller, virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request, const PsRequestMessage *request,
::paddle::PsResponseMessage *response, PsResponseMessage *response,
::google::protobuf::Closure *done) override = 0; ::google::protobuf::Closure *done) override = 0;
virtual void set_response_code(PsResponseMessage &response, int err_code, virtual void set_response_code(PsResponseMessage &response, int err_code,
...@@ -159,7 +161,7 @@ class PsBaseService : public PsService { ...@@ -159,7 +161,7 @@ class PsBaseService : public PsService {
PSServer *_server; PSServer *_server;
const ServerParameter *_config; const ServerParameter *_config;
}; };
REGISTER_REGISTERER(PsBaseService); REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory { class PSServerFactory {
public: public:
......
...@@ -28,6 +28,10 @@ limitations under the License. */ ...@@ -28,6 +28,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;
class PSCore { class PSCore {
public: public:
explicit PSCore() {} explicit PSCore() {}
......
...@@ -165,6 +165,6 @@ class ValueAccessor { ...@@ -165,6 +165,6 @@ class ValueAccessor {
std::unordered_map<int, std::shared_ptr<struct DataConverter>> std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map; _data_coverter_map;
}; };
REGISTER_REGISTERER(ValueAccessor); REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/table/table.h" #include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp> #include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp> #include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h" #include "glog/logging.h"
...@@ -27,14 +28,14 @@ ...@@ -27,14 +28,14 @@
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
REGISTER_CLASS(Table, CommonDenseTable); REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_CLASS(Table, CommonSparseTable); REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
REGISTER_CLASS(Table, SparseGeoTable); REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_CLASS(Table, BarrierTable); REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_CLASS(Table, TensorTable); REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_CLASS(Table, DenseTensorTable); REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_CLASS(Table, GlobalStepTable); REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_CLASS(ValueAccessor, CommMergeAccessor); REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
int32_t TableManager::initialize() { int32_t TableManager::initialize() {
static bool initialized = false; static bool initialized = false;
...@@ -61,9 +62,9 @@ int32_t Table::initialize_accessor() { ...@@ -61,9 +62,9 @@ int32_t Table::initialize_accessor() {
<< _config.table_id(); << _config.table_id();
return -1; return -1;
} }
auto *accessor = auto *accessor = CREATE_PSCORE_CLASS(
CREATE_CLASS(ValueAccessor, ValueAccessor,
_config.accessor().accessor_class()) if (accessor == NULL) { _config.accessor().accessor_class()) if (accessor == NULL) {
LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id() LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id()
<< ", accessor_name:" << _config.accessor().accessor_class(); << ", accessor_name:" << _config.accessor().accessor_class();
return -1; return -1;
......
...@@ -127,7 +127,7 @@ class Table { ...@@ -127,7 +127,7 @@ class Table {
float *_global_lr = nullptr; float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor; std::shared_ptr<ValueAccessor> _value_accesor;
}; };
REGISTER_REGISTERER(Table); REGISTER_PSCORE_REGISTERER(Table);
class TableManager { class TableManager {
public: public:
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <unistd.h> #include <unistd.h>
#include <condition_variable> // NOLINT #include <condition_variable> // NOLINT
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
...@@ -94,7 +95,7 @@ void GetDownpourDenseTableProto( ...@@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param(); server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto = ::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param(); downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService"); server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0); server_service_proto->set_start_server_port(0);
...@@ -124,7 +125,7 @@ void GetDownpourDenseTableProto( ...@@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param(); server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto = ::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param(); downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService"); server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0); server_service_proto->set_start_server_port(0);
...@@ -244,7 +245,8 @@ void RunBrpcPushDense() { ...@@ -244,7 +245,8 @@ void RunBrpcPushDense() {
int ret = 0; int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) { for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) { if (closure->check_response(
i, paddle::distributed::PS_PUSH_DENSE_TABLE) != 0) {
ret = -1; ret = -1;
break; break;
} }
......
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include "google/protobuf/text_format.h" #include "google/protobuf/text_format.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
...@@ -94,7 +95,7 @@ void GetDownpourSparseTableProto( ...@@ -94,7 +95,7 @@ void GetDownpourSparseTableProto(
server_proto->mutable_downpour_server_param(); server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto = ::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param(); downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService"); server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0); server_service_proto->set_start_server_port(0);
...@@ -124,7 +125,7 @@ void GetDownpourSparseTableProto( ...@@ -124,7 +125,7 @@ void GetDownpourSparseTableProto(
server_proto->mutable_downpour_server_param(); server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto = ::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param(); downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService"); server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer"); server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient"); server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0); server_service_proto->set_start_server_port(0);
...@@ -225,7 +226,8 @@ void RunBrpcPushSparse() { ...@@ -225,7 +226,8 @@ void RunBrpcPushSparse() {
int ret = 0; int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) { for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_SPARSE_PARAM) != 0) { if (closure->check_response(
i, paddle::distributed::PS_PUSH_SPARSE_PARAM) != 0) {
ret = -1; ret = -1;
break; break;
} }
...@@ -252,7 +254,8 @@ void RunBrpcPushSparse() { ...@@ -252,7 +254,8 @@ void RunBrpcPushSparse() {
int ret = 0; int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) { for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_SPARSE_TABLE) != 0) { if (closure->check_response(
i, paddle::distributed::PS_PUSH_SPARSE_TABLE) != 0) {
ret = -1; ret = -1;
break; break;
} }
......
...@@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) { ...@@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) {
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
CreateVarsOnScope(&scope, &place, ctx); CreateVarsOnScope(&scope, &place, ctx);
::paddle::MultiVariableMessage multi_msg; ::paddle::distributed::MultiVariableMessage multi_msg;
std::string message_name("se_de_test"); std::string message_name("se_de_test");
std::vector<std::string> send_var_name = {"x1", "x2", "x3"}; std::vector<std::string> send_var_name = {"x1", "x2", "x3"};
std::vector<std::string> recv_var_name = {}; std::vector<std::string> recv_var_name = {};
...@@ -138,4 +138,4 @@ TEST(MultiVarMsgCPU, Run) { ...@@ -138,4 +138,4 @@ TEST(MultiVarMsgCPU, Run) {
// platform::CUDAPlace place; // platform::CUDAPlace place;
// RunMultiVarMsg(place); // RunMultiVarMsg(place);
// } // }
// #endif // #endif
\ No newline at end of file
...@@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE) ...@@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS} lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto pslib_brpc) heter_service_proto pslib_brpc)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else() elseif(WITH_PSCORE)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc heterxpu_trainer.cc
...@@ -230,6 +230,16 @@ if(WITH_DISTRIBUTE) ...@@ -230,6 +230,16 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(multi_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(multi_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(hogwild_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(hogwild_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor)
endif() endif()
elseif(WITH_PSLIB) elseif(WITH_PSLIB)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
...@@ -241,7 +251,6 @@ elseif(WITH_PSLIB) ...@@ -241,7 +251,6 @@ elseif(WITH_PSLIB)
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor pslib_brpc ) graph_to_program_pass variable_helper timer monitor pslib_brpc )
else() else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
......
...@@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he ...@@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows) cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_DISTRIBUTE) if(WITH_PSCORE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h" #include "paddle/fluid/distributed/service/communicator.h"
#endif #endif
...@@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run( ...@@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
"results to be fetched!")); "results to be fetched!"));
// init once // init once
if (run_futures_.size() == 0 && places_.size() > 1) { if (run_futures_.size() == 0 && places_.size() > 1) {
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
if (strategy_.thread_barrier_) { if (strategy_.thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset( paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset(
places_.size()); places_.size());
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h" #include "paddle/fluid/distributed/service/communicator.h"
#endif #endif
...@@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) { ...@@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
void ThreadedSSAGraphExecutor::ExecutionFinal( void ThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) { std::vector<OpHandleBase *> *fetch_ops) {
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
if (strategy_.thread_barrier_) { if (strategy_.thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
} }
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h" #include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h" #include "paddle/fluid/platform/lodtensor_printer.h"
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h" #include "paddle/fluid/distributed/service/communicator.h"
#endif #endif
...@@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() { ...@@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
writer_.Flush(); writer_.Flush();
} }
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
if (thread_barrier_) { if (thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
} }
...@@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() { ...@@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars(); PrintFetchVars();
thread_scope_->DropKids(); thread_scope_->DropKids();
} }
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
if (thread_barrier_) { if (thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement(); paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
} }
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h" #include "paddle/fluid/framework/trainer.h"
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h" #include "paddle/fluid/distributed/service/communicator.h"
#endif #endif
...@@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
VLOG(3) << "worker thread num: " << thread_num_; VLOG(3) << "worker thread num: " << thread_num_;
workers_.resize(thread_num_); workers_.resize(thread_num_);
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
if (trainer_desc.thread_barrier()) { if (trainer_desc.thread_barrier()) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset( paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset(
thread_num_); thread_num_);
......
...@@ -77,12 +77,12 @@ set(SHARED_INFERENCE_SRCS ...@@ -77,12 +77,12 @@ set(SHARED_INFERENCE_SRCS
${mkldnn_quantizer_src_file}) ${mkldnn_quantizer_src_file})
# Create shared inference library defaultly # Create shared inference library defaultly
if(NOT WITH_DISTRIBUTE) if(NOT WITH_PSCORE)
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} analysis_predictor) DEPS ${fluid_modules} analysis_predictor)
else() else()
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS} cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} analysis_predictor fleet ps_service) DEPS ${fluid_modules} analysis_predictor fleet ps_service)
endif() endif()
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
......
...@@ -22,10 +22,13 @@ add_subdirectory(jit) ...@@ -22,10 +22,13 @@ add_subdirectory(jit)
if(WITH_DISTRIBUTE) if(WITH_DISTRIBUTE)
add_subdirectory(pscore)
add_subdirectory(collective) add_subdirectory(collective)
endif() endif()
if (WITH_PSCORE)
add_subdirectory(pscore)
endif()
add_subdirectory(amp) add_subdirectory(amp)
add_subdirectory(reader) add_subdirectory(reader)
......
if (WITH_PSLIB)
return()
endif()
include(operators) include(operators)
set(DISTRIBUTE_DEPS "") set(DISTRIBUTE_DEPS "")
......
...@@ -46,8 +46,8 @@ class DeviceContext; ...@@ -46,8 +46,8 @@ class DeviceContext;
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
template <class TKey, class TValue> template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> { class DoubleFindMap : public std::unordered_map<TKey, TValue> {
......
...@@ -36,8 +36,8 @@ namespace framework = paddle::framework; ...@@ -36,8 +36,8 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace distributed = paddle::distributed; namespace distributed = paddle::distributed;
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
DECLARE_double(eager_delete_tensor_gb); DECLARE_double(eager_delete_tensor_gb);
USE_OP(scale); USE_OP(scale);
......
...@@ -32,8 +32,8 @@ namespace framework = paddle::framework; ...@@ -32,8 +32,8 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace distributed = paddle::distributed; namespace distributed = paddle::distributed;
using MultiVarMsg = ::paddle::MultiVariableMessage; using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage; using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP(scale); USE_OP(scale);
......
...@@ -49,7 +49,7 @@ if (WITH_CRYPTO) ...@@ -49,7 +49,7 @@ if (WITH_CRYPTO)
set(PYBIND_SRCS ${PYBIND_SRCS} crypto.cc) set(PYBIND_SRCS ${PYBIND_SRCS} crypto.cc)
endif (WITH_CRYPTO) endif (WITH_CRYPTO)
if (WITH_DISTRIBUTE) if (WITH_PSCORE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result") set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")
set_source_files_properties(fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
list(APPEND PYBIND_DEPS fleet communicator) list(APPEND PYBIND_DEPS fleet communicator)
......
...@@ -106,7 +106,7 @@ limitations under the License. */ ...@@ -106,7 +106,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/crypto.h" #include "paddle/fluid/pybind/crypto.h"
#endif #endif
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/pybind/fleet_py.h" #include "paddle/fluid/pybind/fleet_py.h"
#endif #endif
...@@ -2833,7 +2833,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -2833,7 +2833,7 @@ All parameter, weight, gradient are variables in Paddle.
BindCrypto(&m); BindCrypto(&m);
#endif #endif
#ifdef PADDLE_WITH_DISTRIBUTE #if defined PADDLE_WITH_PSCORE
BindDistFleetWrapper(&m); BindDistFleetWrapper(&m);
BindPSHost(&m); BindPSHost(&m);
BindCommunicatorContext(&m); BindCommunicatorContext(&m);
......
...@@ -236,7 +236,8 @@ function cmake_base() { ...@@ -236,7 +236,8 @@ function cmake_base() {
-DPY_VERSION=${PY_VERSION:-2.7} -DPY_VERSION=${PY_VERSION:-2.7}
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build}
-DWITH_GRPC=${grpc_flag} -DWITH_GRPC=${grpc_flag}
-DWITH_GLOO=${gloo_flag} -DWITH_PSCORE=${distibuted_flag}
-DWITH_GLOO=${gloo_flag}
-DWITH_LITE=${WITH_LITE:-OFF} -DWITH_LITE=${WITH_LITE:-OFF}
-DWITH_XPU=${WITH_XPU:-OFF} -DWITH_XPU=${WITH_XPU:-OFF}
-DLITE_GIT_TAG=develop -DLITE_GIT_TAG=develop
...@@ -269,7 +270,8 @@ EOF ...@@ -269,7 +270,8 @@ EOF
-DPY_VERSION=${PY_VERSION:-2.7} \ -DPY_VERSION=${PY_VERSION:-2.7} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} \ -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} \
-DWITH_GRPC=${grpc_flag} \ -DWITH_GRPC=${grpc_flag} \
-DWITH_GLOO=${gloo_flag} \ -DWITH_PSCORE=${distibuted_flag} \
-DWITH_GLOO=${gloo_flag} \
-DLITE_GIT_TAG=develop \ -DLITE_GIT_TAG=develop \
-DWITH_XPU=${WITH_XPU:-OFF} \ -DWITH_XPU=${WITH_XPU:-OFF} \
-DWITH_LITE=${WITH_LITE:-OFF};build_error=$? -DWITH_LITE=${WITH_LITE:-OFF};build_error=$?
......
...@@ -59,7 +59,8 @@ int main(int argc, char** argv) { ...@@ -59,7 +59,8 @@ int main(int argc, char** argv) {
std::vector<std::string> envs; std::vector<std::string> envs;
std::vector<std::string> undefok; std::vector<std::string> undefok;
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) #if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) && \
!defined(PADDLE_WITH_PSLIB)
std::string str_max_body_size; std::string str_max_body_size;
if (google::GetCommandLineOption("max_body_size", &str_max_body_size)) { if (google::GetCommandLineOption("max_body_size", &str_max_body_size)) {
setenv("FLAGS_max_body_size", "2147483647", 1); setenv("FLAGS_max_body_size", "2147483647", 1);
......
...@@ -268,7 +268,7 @@ class Service: ...@@ -268,7 +268,7 @@ class Service:
def __init__(self): def __init__(self):
self.server_class = "BrpcPsServer" self.server_class = "BrpcPsServer"
self.client_class = "BrpcPsClient" self.client_class = "BrpcPsClient"
self.service_class = "PsService" self.service_class = "BrpcPsService"
self.start_server_port = 0 self.start_server_port = 0
self.server_thread_num = 12 self.server_thread_num = 12
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册