未验证 提交 25f80fd3 编写于 作者: T tangwei12 提交者: GitHub

Fix/distributed proto (#29981)

* rename sendrecv.proto to namespace paddle.distributed

* split ps with distributed
上级 d479ae17
......@@ -160,6 +160,7 @@ option(WITH_BOX_PS "Compile with box_ps support" OFF)
option(WITH_XBYAK "Compile with xbyak support" ON)
option(WITH_CONTRIB "Compile the third-party contributation" OFF)
option(WITH_GRPC "Use grpc as the default rpc framework" ${WITH_DISTRIBUTE})
option(WITH_PSCORE "Compile with parameter server support" ${WITH_DISTRIBUTE})
option(WITH_INFERENCE_API_TEST "Test fluid inference C++ high-level api interface" OFF)
option(PY_VERSION "Compile PaddlePaddle with python3 support" ${PY_VERSION})
option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE})
......
......@@ -160,6 +160,11 @@ if(WITH_DISTRIBUTE)
add_definitions(-DPADDLE_WITH_DISTRIBUTE)
endif()
if(WITH_PSCORE)
add_definitions(-DPADDLE_WITH_PSCORE)
endif()
if(WITH_GRPC)
add_definitions(-DPADDLE_WITH_GRPC)
endif(WITH_GRPC)
......
......@@ -274,7 +274,7 @@ if(WITH_BOX_PS)
list(APPEND third_party_deps extern_box_ps)
endif(WITH_BOX_PS)
if (WITH_DISTRIBUTE)
if (WITH_PSCORE)
include(external/snappy)
list(APPEND third_party_deps extern_snappy)
......
if (WITH_PSLIB)
return()
endif()
if(NOT WITH_DISTRIBUTE)
if(NOT WITH_PSCORE)
return()
endif()
......
......@@ -69,24 +69,24 @@ class ObjectFactory {
};
typedef std::map<std::string, ObjectFactory *> FactoryMap;
typedef std::map<std::string, FactoryMap> BaseClassMap;
typedef std::map<std::string, FactoryMap> PsCoreClassMap;
#ifdef __cplusplus
extern "C" {
#endif
inline BaseClassMap &global_factory_map() {
static BaseClassMap *base_class = new BaseClassMap();
inline PsCoreClassMap &global_factory_map() {
static PsCoreClassMap *base_class = new PsCoreClassMap();
return *base_class;
}
#ifdef __cplusplus
}
#endif
inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
inline PsCoreClassMap &global_factory_map_cpp() { return global_factory_map(); }
// typedef pa::Any Any;
// typedef ::FactoryMap FactoryMap;
#define REGISTER_REGISTERER(base_class) \
#define REGISTER_PSCORE_REGISTERER(base_class) \
class base_class##Registerer { \
public: \
static base_class *CreateInstanceByName(const ::std::string &name) { \
......@@ -107,7 +107,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
};
#define REGISTER_CLASS(clazz, name) \
#define REGISTER_PSCORE_CLASS(clazz, name) \
class ObjectFactory##name : public ObjectFactory { \
public: \
Any NewInstance() { return Any(new name()); } \
......@@ -120,7 +120,7 @@ inline BaseClassMap &global_factory_map_cpp() { return global_factory_map(); }
} \
void register_factory_##name() __attribute__((constructor));
#define CREATE_CLASS(base_class, name) \
#define CREATE_PSCORE_CLASS(base_class, name) \
base_class##Registerer::CreateInstanceByName(name);
} // namespace distributed
......
......@@ -86,7 +86,7 @@ message SparseTableParameter {
message ServerServiceParameter {
optional string server_class = 1 [ default = "BrpcPsServer" ];
optional string client_class = 2 [ default = "BrpcPsClient" ];
optional string service_class = 3 [ default = "PsService" ];
optional string service_class = 3 [ default = "BrpcPsService" ];
optional uint32 start_server_port = 4
[ default = 0 ]; // will find an available port starting from it
optional uint32 server_thread_num = 5 [ default = 12 ];
......
......@@ -17,8 +17,8 @@
#include <sstream>
#include <string>
#include <vector>
#include "Eigen/Dense"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
#include "paddle/fluid/distributed/table/table.h"
#include "paddle/fluid/framework/archive.h"
......@@ -80,8 +80,8 @@ inline size_t get_sparse_shard(uint32_t shard_num, uint32_t server_num,
void DownpourPsClientService::service(
::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response, ::google::protobuf::Closure *done) {
const PsRequestMessage *request, PsResponseMessage *response,
::google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
int ret = _client->handle_client2client_msg(
request->cmd_id(), request->client_id(), request->data());
......
......@@ -40,8 +40,8 @@ class DownpourPsClientService : public PsService {
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
protected:
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <thread> // NOLINT
#include "Eigen/Dense"
#include "butil/endpoint.h"
......@@ -30,7 +31,8 @@ int32_t BrpcPsServer::initialize() {
LOG(ERROR) << "miss service_class in ServerServiceParameter";
return -1;
}
auto *service = CREATE_CLASS(PsBaseService, service_config.service_class());
auto *service =
CREATE_PSCORE_CLASS(PsBaseService, service_config.service_class());
if (service == NULL) {
LOG(ERROR) << "service is unregistered, service_name:"
<< service_config.service_class();
......@@ -79,28 +81,28 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
// Returns the port field of the address reported by _server.listen_address().
int32_t BrpcPsServer::port() { return _server.listen_address().port; }
int32_t PsService::initialize() {
int32_t BrpcPsService::initialize() {
_is_initialize_shard_info = false;
_service_handler_map[PS_STOP_SERVER] = &PsService::stop_server;
_service_handler_map[PS_PULL_DENSE_TABLE] = &PsService::pull_dense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &PsService::push_dense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &PsService::pull_sparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &PsService::push_sparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &PsService::save_one_table;
_service_handler_map[PS_SAVE_ALL_TABLE] = &PsService::save_all_table;
_service_handler_map[PS_SHRINK_TABLE] = &PsService::shrink_table;
_service_handler_map[PS_LOAD_ONE_TABLE] = &PsService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &PsService::load_all_table;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &PsService::clear_one_table;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &PsService::clear_all_table;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &PsService::push_dense_param;
_service_handler_map[PS_PRINT_TABLE_STAT] = &PsService::print_table_stat;
_service_handler_map[PS_PULL_GEO_PARAM] = &PsService::pull_geo_param;
_service_handler_map[PS_PUSH_SPARSE_PARAM] = &PsService::push_sparse_param;
_service_handler_map[PS_BARRIER] = &PsService::barrier;
_service_handler_map[PS_START_PROFILER] = &PsService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &PsService::stop_profiler;
_service_handler_map[PS_PUSH_GLOBAL_STEP] = &PsService::push_global_step;
_service_handler_map[PS_STOP_SERVER] = &BrpcPsService::stop_server;
_service_handler_map[PS_PULL_DENSE_TABLE] = &BrpcPsService::pull_dense;
_service_handler_map[PS_PUSH_DENSE_TABLE] = &BrpcPsService::push_dense;
_service_handler_map[PS_PULL_SPARSE_TABLE] = &BrpcPsService::pull_sparse;
_service_handler_map[PS_PUSH_SPARSE_TABLE] = &BrpcPsService::push_sparse;
_service_handler_map[PS_SAVE_ONE_TABLE] = &BrpcPsService::save_one_table;
_service_handler_map[PS_SAVE_ALL_TABLE] = &BrpcPsService::save_all_table;
_service_handler_map[PS_SHRINK_TABLE] = &BrpcPsService::shrink_table;
_service_handler_map[PS_LOAD_ONE_TABLE] = &BrpcPsService::load_one_table;
_service_handler_map[PS_LOAD_ALL_TABLE] = &BrpcPsService::load_all_table;
_service_handler_map[PS_CLEAR_ONE_TABLE] = &BrpcPsService::clear_one_table;
_service_handler_map[PS_CLEAR_ALL_TABLE] = &BrpcPsService::clear_all_table;
_service_handler_map[PS_PUSH_DENSE_PARAM] = &BrpcPsService::push_dense_param;
_service_handler_map[PS_PRINT_TABLE_STAT] = &BrpcPsService::print_table_stat;
_service_handler_map[PS_PULL_GEO_PARAM] = &BrpcPsService::pull_geo_param;
_service_handler_map[PS_PUSH_SPARSE_PARAM] =
&BrpcPsService::push_sparse_param;
_service_handler_map[PS_BARRIER] = &BrpcPsService::barrier;
_service_handler_map[PS_START_PROFILER] = &BrpcPsService::start_profiler;
_service_handler_map[PS_STOP_PROFILER] = &BrpcPsService::stop_profiler;
// Initialize shard info; the shard info of server_list can only be fetched from env after the server has started
initialize_shard_info();
......@@ -116,7 +118,7 @@ int32_t PsService::initialize() {
return -1; \
}
int32_t PsService::initialize_shard_info() {
int32_t BrpcPsService::initialize_shard_info() {
if (!_is_initialize_shard_info) {
std::lock_guard<std::mutex> guard(_initialize_shard_mutex);
if (_is_initialize_shard_info) {
......@@ -132,10 +134,10 @@ int32_t PsService::initialize_shard_info() {
return 0;
}
void PsService::service(google::protobuf::RpcController *cntl_base,
const PsRequestMessage *request,
PsResponseMessage *response,
google::protobuf::Closure *done) {
void BrpcPsService::service(google::protobuf::RpcController *cntl_base,
const PsRequestMessage *request,
PsResponseMessage *response,
google::protobuf::Closure *done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
if (!request->has_table_id()) {
......@@ -163,9 +165,9 @@ void PsService::service(google::protobuf::RpcController *cntl_base,
}
}
int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::pull_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_dense");
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
......@@ -191,10 +193,10 @@ int32_t PsService::pull_dense(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::push_dense_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::push_dense_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense_param");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_buffer;
......@@ -218,9 +220,9 @@ int32_t PsService::push_dense_param(Table *table,
return 0;
}
int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::push_dense(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_dense");
CHECK_TABLE_EXIST(table, request, response)
auto req_buffer_size = request.data().size();
......@@ -244,9 +246,9 @@ int32_t PsService::push_dense(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::barrier(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 1) {
......@@ -262,10 +264,10 @@ int32_t PsService::barrier(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::push_sparse_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::push_sparse_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse_param");
CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data();
......@@ -294,9 +296,10 @@ int32_t PsService::push_sparse_param(Table *table,
return 0;
}
int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::pull_geo_param(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_geo_param");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer;
......@@ -316,9 +319,10 @@ int32_t PsService::pull_geo_param(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::pull_sparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->pull_sparse");
CHECK_TABLE_EXIST(table, request, response)
thread_local std::string push_sparse_request_buffer;
......@@ -353,9 +357,10 @@ int32_t PsService::pull_sparse(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::push_sparse(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::RecordEvent record_event("PsService->push_sparse");
CHECK_TABLE_EXIST(table, request, response)
auto &push_data = request.data();
......@@ -384,10 +389,10 @@ int32_t PsService::push_sparse(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::print_table_stat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::print_table_stat(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
std::pair<int64_t, int64_t> ret = table->print_table_stat();
paddle::framework::BinaryArchive ar;
......@@ -398,9 +403,10 @@ int32_t PsService::print_table_stat(Table *table,
return 0;
}
int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::load_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
......@@ -415,9 +421,10 @@ int32_t PsService::load_one_table(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::load_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
for (auto &itr : table_map) {
if (load_one_table(itr.second.get(), request, response, cntl) != 0) {
......@@ -428,9 +435,10 @@ int32_t PsService::load_all_table(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::save_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
if (request.params_size() < 2) {
set_response_code(
......@@ -449,9 +457,10 @@ int32_t PsService::save_one_table(Table *table, const PsRequestMessage &request,
return feasign_size;
}
int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::save_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
int32_t all_feasign_size = 0;
int32_t feasign_size = 0;
......@@ -466,9 +475,10 @@ int32_t PsService::save_all_table(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::shrink_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response)
table->flush();
if (table->shrink() != 0) {
......@@ -477,20 +487,20 @@ int32_t PsService::shrink_table(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::clear_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
// Command handler for clearing a single table (registered under
// PS_CLEAR_ONE_TABLE in BrpcPsService::initialize, and also invoked per table
// by clear_all_table).
//
// table:    table to clear; validated by CHECK_TABLE_EXIST before use.
// request:  incoming request; only passed to CHECK_TABLE_EXIST here.
// response: outgoing response; only passed to CHECK_TABLE_EXIST here.
// cntl:     not read by this handler.
//
// Returns 0 after the table has been flushed and cleared.
int32_t BrpcPsService::clear_one_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
// NOTE(review): macro body is not fully visible here; presumably it returns
// -1 early when `table` is null -- confirm against its definition.
CHECK_TABLE_EXIST(table, request, response)
// Flush first, then clear; the return values of both calls are ignored.
table->flush();
table->clear();
return 0;
}
int32_t PsService::clear_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::clear_all_table(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto &table_map = *(_server->table());
for (auto &itr : table_map) {
if (clear_one_table(itr.second.get(), request, response, cntl) != 0) {
......@@ -500,9 +510,10 @@ int32_t PsService::clear_all_table(Table *table,
return 0;
}
int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::stop_server(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
auto *p_server = _server;
std::thread t_stop([p_server]() {
p_server->stop();
......@@ -512,25 +523,27 @@ int32_t PsService::stop_server(Table *table, const PsRequestMessage &request,
return 0;
}
int32_t PsService::stop_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
// Command handler registered under PS_STOP_PROFILER.
//
// Calls platform::DisableProfiler with the default event sorting key and a
// name formatted from _rank as "server_<rank>_profile". None of the
// parameters (table, request, response, cntl) are read. Always returns 0.
int32_t BrpcPsService::stop_profiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
// NOTE(review): the second argument is presumably the profile output path --
// confirm against platform::DisableProfiler.
platform::DisableProfiler(platform::EventSortingKey::kDefault,
string::Sprintf("server_%s_profile", _rank));
return 0;
}
int32_t PsService::start_profiler(Table *table, const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
// Command handler registered under PS_START_PROFILER.
//
// Calls platform::EnableProfiler with ProfilerState::kCPU. None of the
// parameters (table, request, response, cntl) are read. Always returns 0.
int32_t BrpcPsService::start_profiler(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
platform::EnableProfiler(platform::ProfilerState::kCPU);
return 0;
}
int32_t PsService::push_global_step(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
int32_t BrpcPsService::push_global_step(Table *table,
const PsRequestMessage &request,
PsResponseMessage &response,
brpc::Controller *cntl) {
CHECK_TABLE_EXIST(table, request, response);
auto req_buffer_size = request.data().size();
if (req_buffer_size < 1) {
......
......@@ -52,19 +52,19 @@ class BrpcPsServer : public PSServer {
std::vector<std::shared_ptr<brpc::Channel>> _pserver_channels;
};
class PsService;
class BrpcPsService;
typedef int32_t (PsService::*serviceHandlerFunc)(
typedef int32_t (BrpcPsService::*serviceHandlerFunc)(
Table *table, const PsRequestMessage &request, PsResponseMessage &response,
brpc::Controller *cntl);
class PsService : public PsBaseService {
class BrpcPsService : public PsBaseService {
public:
virtual int32_t initialize() override;
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override;
private:
......
......@@ -88,7 +88,7 @@ void SerializeLodTensor(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* var_msg,
butil::IOBuf* iobuf) {
auto* tensor = var->GetMutable<framework::LoDTensor>();
var_msg->set_type(::paddle::LOD_TENSOR);
var_msg->set_type(::paddle::distributed::LOD_TENSOR);
const framework::LoD lod = tensor->lod();
if (lod.size() > 0) {
var_msg->set_lod_level(lod.size());
......@@ -135,7 +135,7 @@ void SerializeSelectedRows(framework::Variable* var,
auto* tensor = slr->mutable_value();
auto* rows = slr->mutable_rows();
var_msg->set_type(::paddle::SELECTED_ROWS);
var_msg->set_type(::paddle::distributed::SELECTED_ROWS);
var_msg->set_slr_height(slr->height());
auto* var_data = var_msg->mutable_data();
......@@ -194,9 +194,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
++recv_var_index) {
const auto& msg = multi_msg.var_messages(recv_var_index);
auto* var = scope->Var(msg.varname());
if (msg.type() == ::paddle::LOD_TENSOR) {
if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
} else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}
......@@ -215,9 +215,9 @@ void DeserializeFromMultiVarMsgAndIOBuf(const MultiVarMsg& multi_msg,
PADDLE_ENFORCE_NE(var, nullptr,
platform::errors::InvalidArgument(
"Not find variable %s in scope.", msg.varname()));
if (msg.type() == ::paddle::LOD_TENSOR) {
if (msg.type() == ::paddle::distributed::LOD_TENSOR) {
DeserializeLodTensor(var, msg, io_buffer_itr, ctx);
} else if (msg.type() == ::paddle::SELECTED_ROWS) {
} else if (msg.type() == ::paddle::distributed::SELECTED_ROWS) {
DeserializeSelectedRows(var, msg, io_buffer_itr, ctx);
}
}
......
......@@ -44,8 +44,8 @@ class DeviceContext;
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
void SerializeToMultiVarMsgAndIOBuf(
const std::string& message_name,
......
......@@ -122,7 +122,7 @@ void HeterClient::SendAndRecvAsync(
cntl.set_timeout_ms(FLAGS_pserver_timeout_ms);
distributed::MultiVarMsg request, response;
auto& request_io_buffer = cntl.request_attachment();
::paddle::PsService_Stub stub(xpu_channels_[num].get());
::paddle::distributed::PsService_Stub stub(xpu_channels_[num].get());
distributed::SerializeToMultiVarMsgAndIOBuf(
message_name_val, send_var_name_val, recv_var_name_val, *p_ctx, p_scope,
&request, &request_io_buffer);
......@@ -164,7 +164,7 @@ std::future<int32_t> HeterClient::SendCmd(
for (const auto& param : params) {
closure->request(i)->add_params(param);
}
::paddle::PsService_Stub rpc_stub(xpu_channels_[i].get());
::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms(
FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service(closure->cntl(i), closure->request(i),
......
......@@ -35,8 +35,8 @@ limitations under the License. */
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
typedef std::function<void(void*)> HeterRpcCallbackFunc;
......
......@@ -39,8 +39,8 @@ DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
namespace distributed {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
class HeterService;
typedef int32_t (HeterService::*serviceHandlerFunc)(
......@@ -51,7 +51,7 @@ typedef std::function<void(void*)> HeterRpcCallbackFunc;
typedef std::function<int(const MultiVarMsg*, MultiVarMsg*, brpc::Controller*)>
HeterServiceHandler;
class HeterService : public ::paddle::PsService {
class HeterService : public ::paddle::distributed::PsService {
public:
HeterService() {
_service_handler_map[PS_STOP_SERVER] = &HeterService::stop_heter_worker;
......@@ -62,8 +62,8 @@ class HeterService : public ::paddle::PsService {
virtual ~HeterService() {}
virtual void service(::google::protobuf::RpcController* controller,
const ::paddle::PsRequestMessage* request,
::paddle::PsResponseMessage* response,
const PsRequestMessage* request,
PsResponseMessage* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
std::string log_label("ReceiveCmd-");
......
......@@ -13,9 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/ps_client.h"
#include <map>
#include "brpc/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_client.h"
......@@ -23,7 +21,7 @@
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSClient, BrpcPsClient);
REGISTER_PSCORE_CLASS(PSClient, BrpcPsClient);
int32_t PSClient::configure(
const PSParameter &config,
......@@ -43,7 +41,7 @@ int32_t PSClient::configure(
const auto &work_param = _config.worker_param().downpour_worker_param();
for (size_t i = 0; i < work_param.downpour_table_param_size(); ++i) {
auto *accessor = CREATE_CLASS(
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
work_param.downpour_table_param(i).accessor().accessor_class());
accessor->configure(work_param.downpour_table_param(i).accessor());
......@@ -73,7 +71,8 @@ PSClient *PSClientFactory::create(const PSParameter &ps_config) {
}
const auto &service_param = config.downpour_server_param().service_param();
PSClient *client = CREATE_CLASS(PSClient, service_param.client_class());
PSClient *client =
CREATE_PSCORE_CLASS(PSClient, service_param.client_class());
if (client == NULL) {
LOG(ERROR) << "client is not registered, server_name:"
<< service_param.client_class();
......
......@@ -28,6 +28,9 @@
namespace paddle {
namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
typedef std::function<void(void *)> PSClientCallBack;
class PSClientClosure : public google::protobuf::Closure {
public:
......@@ -206,7 +209,7 @@ class PSClient {
std::unordered_map<int32_t, MsgHandlerFunc>
_msg_handler_map; // handles client2client messages
};
REGISTER_REGISTERER(PSClient);
REGISTER_PSCORE_REGISTERER(PSClient);
class PSClientFactory {
public:
......
......@@ -13,7 +13,7 @@
// limitations under the License.
syntax = "proto2";
package paddle;
package paddle.distributed;
option cc_generic_services = true;
option cc_enable_arenas = true;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/service/server.h"
#include "glog/logging.h"
#include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include "paddle/fluid/distributed/table/table.h"
......@@ -20,8 +21,8 @@
namespace paddle {
namespace distributed {
REGISTER_CLASS(PSServer, BrpcPsServer);
REGISTER_CLASS(PsBaseService, PsService);
REGISTER_PSCORE_CLASS(PSServer, BrpcPsServer);
REGISTER_PSCORE_CLASS(PsBaseService, BrpcPsService);
PSServer *PSServerFactory::create(const PSParameter &ps_config) {
const auto &config = ps_config.server_param();
......@@ -43,7 +44,8 @@ PSServer *PSServerFactory::create(const PSParameter &ps_config) {
}
const auto &service_param = config.downpour_server_param().service_param();
PSServer *server = CREATE_CLASS(PSServer, service_param.server_class());
PSServer *server =
CREATE_PSCORE_CLASS(PSServer, service_param.server_class());
if (server == NULL) {
LOG(ERROR) << "server is not registered, server_name:"
<< service_param.server_class();
......@@ -70,7 +72,7 @@ int32_t PSServer::configure(
uint32_t global_step_table = UINT32_MAX;
for (size_t i = 0; i < downpour_param.downpour_table_param_size(); ++i) {
auto *table = CREATE_CLASS(
auto *table = CREATE_PSCORE_CLASS(
Table, downpour_param.downpour_table_param(i).table_class());
if (downpour_param.downpour_table_param(i).table_class() ==
......
......@@ -46,6 +46,8 @@ namespace paddle {
namespace distributed {
class Table;
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
class PSServer {
public:
......@@ -107,7 +109,7 @@ class PSServer {
platform::Place place_ = platform::CPUPlace();
};
REGISTER_REGISTERER(PSServer);
REGISTER_PSCORE_REGISTERER(PSServer);
typedef std::function<void(void *)> PServerCallBack;
......@@ -141,8 +143,8 @@ class PsBaseService : public PsService {
return 0;
}
virtual void service(::google::protobuf::RpcController *controller,
const ::paddle::PsRequestMessage *request,
::paddle::PsResponseMessage *response,
const PsRequestMessage *request,
PsResponseMessage *response,
::google::protobuf::Closure *done) override = 0;
virtual void set_response_code(PsResponseMessage &response, int err_code,
......@@ -159,7 +161,7 @@ class PsBaseService : public PsService {
PSServer *_server;
const ServerParameter *_config;
};
REGISTER_REGISTERER(PsBaseService);
REGISTER_PSCORE_REGISTERER(PsBaseService);
class PSServerFactory {
public:
......
......@@ -28,6 +28,10 @@ limitations under the License. */
namespace paddle {
namespace distributed {
using paddle::distributed::PsRequestMessage;
using paddle::distributed::PsResponseMessage;
using paddle::distributed::PsService;
class PSCore {
public:
explicit PSCore() {}
......
......@@ -165,6 +165,6 @@ class ValueAccessor {
std::unordered_map<int, std::shared_ptr<struct DataConverter>>
_data_coverter_map;
};
REGISTER_REGISTERER(ValueAccessor);
REGISTER_PSCORE_REGISTERER(ValueAccessor);
} // namespace distributed
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/distributed/table/table.h"
#include <boost/preprocessor/repetition/repeat_from_to.hpp>
#include <boost/preprocessor/seq/elem.hpp>
#include "glog/logging.h"
......@@ -27,14 +28,14 @@
namespace paddle {
namespace distributed {
REGISTER_CLASS(Table, CommonDenseTable);
REGISTER_CLASS(Table, CommonSparseTable);
REGISTER_CLASS(Table, SparseGeoTable);
REGISTER_CLASS(Table, BarrierTable);
REGISTER_CLASS(Table, TensorTable);
REGISTER_CLASS(Table, DenseTensorTable);
REGISTER_CLASS(Table, GlobalStepTable);
REGISTER_CLASS(ValueAccessor, CommMergeAccessor);
REGISTER_PSCORE_CLASS(Table, CommonDenseTable);
REGISTER_PSCORE_CLASS(Table, CommonSparseTable);
REGISTER_PSCORE_CLASS(Table, SparseGeoTable);
REGISTER_PSCORE_CLASS(Table, BarrierTable);
REGISTER_PSCORE_CLASS(Table, TensorTable);
REGISTER_PSCORE_CLASS(Table, DenseTensorTable);
REGISTER_PSCORE_CLASS(Table, GlobalStepTable);
REGISTER_PSCORE_CLASS(ValueAccessor, CommMergeAccessor);
int32_t TableManager::initialize() {
static bool initialized = false;
......@@ -61,9 +62,9 @@ int32_t Table::initialize_accessor() {
<< _config.table_id();
return -1;
}
auto *accessor =
CREATE_CLASS(ValueAccessor,
_config.accessor().accessor_class()) if (accessor == NULL) {
auto *accessor = CREATE_PSCORE_CLASS(
ValueAccessor,
_config.accessor().accessor_class()) if (accessor == NULL) {
LOG(ERROR) << "accessor is unregisteg, table_id:" << _config.table_id()
<< ", accessor_name:" << _config.accessor().accessor_class();
return -1;
......
......@@ -127,7 +127,7 @@ class Table {
float *_global_lr = nullptr;
std::shared_ptr<ValueAccessor> _value_accesor;
};
REGISTER_REGISTERER(Table);
REGISTER_PSCORE_REGISTERER(Table);
class TableManager {
public:
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <condition_variable> // NOLINT
#include <string>
#include <thread> // NOLINT
......@@ -94,7 +95,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
......@@ -124,7 +125,7 @@ void GetDownpourDenseTableProto(
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
......@@ -244,7 +245,8 @@ void RunBrpcPushDense() {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_DENSE_TABLE) != 0) {
if (closure->check_response(
i, paddle::distributed::PS_PUSH_DENSE_TABLE) != 0) {
ret = -1;
break;
}
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <thread> // NOLINT
#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
......@@ -94,7 +95,7 @@ void GetDownpourSparseTableProto(
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
......@@ -124,7 +125,7 @@ void GetDownpourSparseTableProto(
server_proto->mutable_downpour_server_param();
::paddle::distributed::ServerServiceParameter* server_service_proto =
downpour_server_proto->mutable_service_param();
server_service_proto->set_service_class("PsService");
server_service_proto->set_service_class("BrpcPsService");
server_service_proto->set_server_class("BrpcPsServer");
server_service_proto->set_client_class("BrpcPsClient");
server_service_proto->set_start_server_port(0);
......@@ -225,7 +226,8 @@ void RunBrpcPushSparse() {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_SPARSE_PARAM) != 0) {
if (closure->check_response(
i, paddle::distributed::PS_PUSH_SPARSE_PARAM) != 0) {
ret = -1;
break;
}
......@@ -252,7 +254,8 @@ void RunBrpcPushSparse() {
int ret = 0;
auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < 1; ++i) {
if (closure->check_response(i, paddle::PS_PUSH_SPARSE_TABLE) != 0) {
if (closure->check_response(
i, paddle::distributed::PS_PUSH_SPARSE_TABLE) != 0) {
ret = -1;
break;
}
......
......@@ -75,7 +75,7 @@ void RunMultiVarMsg(platform::Place place) {
auto& ctx = *pool.Get(place);
CreateVarsOnScope(&scope, &place, ctx);
::paddle::MultiVariableMessage multi_msg;
::paddle::distributed::MultiVariableMessage multi_msg;
std::string message_name("se_de_test");
std::vector<std::string> send_var_name = {"x1", "x2", "x3"};
std::vector<std::string> recv_var_name = {};
......@@ -138,4 +138,4 @@ TEST(MultiVarMsgCPU, Run) {
// platform::CUDAPlace place;
// RunMultiVarMsg(place);
// }
// #endif
\ No newline at end of file
// #endif
......@@ -209,12 +209,12 @@ if(WITH_DISTRIBUTE)
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto trainer_desc_proto glog fs shell
fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer
lod_rank_table feed_fetch_method sendrecvop_rpc communicator collective_helper ${GLOB_DISTRIBUTE_DEPS}
lod_rank_table feed_fetch_method collective_helper ${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass variable_helper data_feed_proto timer monitor
heter_service_proto pslib_brpc)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
elseif(WITH_PSCORE)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
......@@ -228,6 +228,16 @@ if(WITH_DISTRIBUTE)
set_source_files_properties(executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(multi_trainer.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(hogwild_worker.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
heterxpu_trainer.cc
data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc ps_gpu_worker.cc
heterbox_worker.cc heterbox_trainer.cc ps_gpu_trainer.cc downpour_worker.cc downpour_worker_opt.cc
pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor)
endif()
elseif(WITH_PSLIB)
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
......@@ -239,7 +249,6 @@ elseif(WITH_PSLIB)
device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
lod_rank_table fs shell fleet_wrapper heter_wrapper ps_gpu_wrapper box_wrapper lodtensor_printer feed_fetch_method
graph_to_program_pass variable_helper timer monitor pslib_brpc )
else()
cc_library(executor SRCS executor.cc multi_trainer.cc pipeline_trainer.cc dataset_factory.cc
dist_multi_trainer.cc trainer_factory.cc trainer.cc data_feed_factory.cc
......
......@@ -14,7 +14,7 @@ cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_he
cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
if(WITH_DISTRIBUTE)
if(WITH_PSCORE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(reduce_op_handle.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(threaded_ssa_graph_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
......
......@@ -16,7 +16,7 @@
#include "paddle/fluid/framework/variable_helper.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#endif
......@@ -138,7 +138,7 @@ FetchResultType AsyncSSAGraphExecutor::Run(
"results to be fetched!"));
// init once
if (run_futures_.size() == 0 && places_.size() > 1) {
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
if (strategy_.thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset(
places_.size());
......
......@@ -17,7 +17,7 @@
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#endif
......@@ -360,7 +360,7 @@ bool ThreadedSSAGraphExecutor::RunOpSync(OpHandleBase *op) {
void ThreadedSSAGraphExecutor::ExecutionFinal(
std::vector<OpHandleBase *> *fetch_ops) {
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
if (strategy_.thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
}
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#endif
......@@ -186,7 +186,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
writer_.Flush();
}
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
if (thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
}
......@@ -216,7 +216,7 @@ void HogwildWorker::TrainFiles() {
PrintFetchVars();
thread_scope_->DropKids();
}
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
if (thread_barrier_) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
}
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#endif
......@@ -49,7 +49,7 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
VLOG(3) << "worker thread num: " << thread_num_;
workers_.resize(thread_num_);
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
if (trainer_desc.thread_barrier()) {
paddle::distributed::Communicator::GetInstance()->BarrierTriggerReset(
thread_num_);
......
......@@ -77,12 +77,12 @@ set(SHARED_INFERENCE_SRCS
${mkldnn_quantizer_src_file})
# Create shared inference library by default
if(NOT WITH_DISTRIBUTE)
if(NOT WITH_PSCORE)
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} analysis_predictor)
DEPS ${fluid_modules} analysis_predictor)
else()
cc_library(paddle_fluid_shared SHARED SRCS ${SHARED_INFERENCE_SRCS}
DEPS ${fluid_modules} analysis_predictor fleet ps_service)
DEPS ${fluid_modules} analysis_predictor fleet ps_service)
endif()
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
......
......@@ -22,10 +22,13 @@ add_subdirectory(jit)
if(WITH_DISTRIBUTE)
add_subdirectory(pscore)
add_subdirectory(collective)
endif()
if (WITH_PSCORE)
add_subdirectory(pscore)
endif()
add_subdirectory(amp)
add_subdirectory(reader)
......
if (WITH_PSLIB)
return()
endif()
include(operators)
set(DISTRIBUTE_DEPS "")
......
......@@ -46,8 +46,8 @@ class DeviceContext;
namespace paddle {
namespace operators {
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
template <class TKey, class TValue>
class DoubleFindMap : public std::unordered_map<TKey, TValue> {
......
......@@ -36,8 +36,8 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::distributed;
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
DECLARE_double(eager_delete_tensor_gb);
USE_OP(scale);
......
......@@ -32,8 +32,8 @@ namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace distributed = paddle::distributed;
using MultiVarMsg = ::paddle::MultiVariableMessage;
using VarMsg = ::paddle::VariableMessage;
using MultiVarMsg = ::paddle::distributed::MultiVariableMessage;
using VarMsg = ::paddle::distributed::VariableMessage;
USE_OP(scale);
......
......@@ -49,7 +49,7 @@ if (WITH_CRYPTO)
set(PYBIND_SRCS ${PYBIND_SRCS} crypto.cc)
endif (WITH_CRYPTO)
if (WITH_DISTRIBUTE)
if (WITH_PSCORE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor -Wno-error=sign-compare -Wno-error=unused-variable -Wno-error=return-type -Wno-error=unused-but-set-variable -Wno-error=type-limits -Wno-error=unknown-pragmas -Wno-error=parentheses -Wno-error=unused-result")
set_source_files_properties(fleet_py.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
list(APPEND PYBIND_DEPS fleet communicator)
......
......@@ -107,7 +107,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/crypto.h"
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/pybind/fleet_py.h"
#endif
......@@ -2841,7 +2841,7 @@ All parameter, weight, gradient are variables in Paddle.
BindCrypto(&m);
#endif
#ifdef PADDLE_WITH_DISTRIBUTE
#if defined PADDLE_WITH_PSCORE
BindDistFleetWrapper(&m);
BindPSHost(&m);
BindCommunicatorContext(&m);
......
......@@ -249,6 +249,7 @@ function cmake_base() {
-DPY_VERSION=${PY_VERSION:-2.7}
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build}
-DWITH_GRPC=${grpc_flag}
-DWITH_PSCORE=${distibuted_flag}
-DWITH_GLOO=${gloo_flag}
-DWITH_LITE=${WITH_LITE:-OFF}
-DWITH_XPU=${WITH_XPU:-OFF}
......@@ -284,6 +285,7 @@ EOF
-DPY_VERSION=${PY_VERSION:-2.7} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} \
-DWITH_GRPC=${grpc_flag} \
-DWITH_PSCORE=${distibuted_flag} \
-DWITH_GLOO=${gloo_flag} \
-DLITE_GIT_TAG=develop \
-DWITH_XPU=${WITH_XPU:-OFF} \
......
......@@ -59,7 +59,8 @@ int main(int argc, char** argv) {
std::vector<std::string> envs;
std::vector<std::string> undefok;
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC)
#if defined(PADDLE_WITH_DISTRIBUTE) && !defined(PADDLE_WITH_GRPC) && \
!defined(PADDLE_WITH_PSLIB)
std::string str_max_body_size;
if (google::GetCommandLineOption("max_body_size", &str_max_body_size)) {
setenv("FLAGS_max_body_size", "2147483647", 1);
......
......@@ -268,7 +268,7 @@ class Service:
def __init__(self):
self.server_class = "BrpcPsServer"
self.client_class = "BrpcPsClient"
self.service_class = "PsService"
self.service_class = "BrpcPsService"
self.start_server_port = 0
self.server_thread_num = 12
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册