未验证 提交 7d8402a8 编写于 作者: H huangjiyi 提交者: GitHub

fix paddle namespace conflict when using paddle_flags (#56913)

* update

* update

* update
上级 25a0b46d
...@@ -85,7 +85,7 @@ void HeterClient::CreateClient2XpuConnection() { ...@@ -85,7 +85,7 @@ void HeterClient::CreateClient2XpuConnection() {
xpu_channels_[i].reset(new brpc::Channel()); xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) { if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again"; VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = ::paddle::string::Split(xpu_list_[i], ':'); auto ip_port = paddle::string::Split(xpu_list_[i], ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
...@@ -100,7 +100,7 @@ void HeterClient::CreateClient2XpuConnection() { ...@@ -100,7 +100,7 @@ void HeterClient::CreateClient2XpuConnection() {
if (previous_xpu_channels_[i]->Init( if (previous_xpu_channels_[i]->Init(
previous_xpu_list_[i].c_str(), "", &options) != 0) { previous_xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterClient channel init fail. Try Again"; VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = ::paddle::string::Split(previous_xpu_list_[i], ':'); auto ip_port = paddle::string::Split(previous_xpu_list_[i], ':');
std::string ip = ip_port[0]; std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]); int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port); std::string int_ip_port = GetIntTypeEndpoint(ip, port);
...@@ -167,13 +167,13 @@ void HeterClient::SendAndRecvAsync( ...@@ -167,13 +167,13 @@ void HeterClient::SendAndRecvAsync(
// int idx = 1; // for test // int idx = 1; // for test
// LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size(); // LOG(INFO) << "xpu_channels_ size: " << xpu_channels_.size();
// channel = xpu_channels_[idx].get(); // 为了适配 send_and_recv op // channel = xpu_channels_[idx].get(); // 为了适配 send_and_recv op
// ::paddle::distributed::PsService_Stub stub(channel); // paddle::distributed::PsService_Stub stub(channel);
// stub.SendToSwitch(&closure->cntl, &request, &closure->response, // stub.SendToSwitch(&closure->cntl, &request, &closure->response,
// closure); fut.wait(); // closure); fut.wait();
VLOG(4) << "calling switch service done"; VLOG(4) << "calling switch service done";
return; return;
} }
::paddle::distributed::PsService_Stub stub(channel); paddle::distributed::PsService_Stub stub(channel);
stub.SendAndRecvVariable( stub.SendAndRecvVariable(
&closure->cntl, &request, &closure->response, closure); &closure->cntl, &request, &closure->response, closure);
} }
...@@ -181,11 +181,11 @@ void HeterClient::SendAndRecvAsync( ...@@ -181,11 +181,11 @@ void HeterClient::SendAndRecvAsync(
std::future<int32_t> HeterClient::SendCmd( std::future<int32_t> HeterClient::SendCmd(
uint32_t table_id, int cmd_id, const std::vector<std::string>& params) { uint32_t table_id, int cmd_id, const std::vector<std::string>& params) {
size_t request_call_num = xpu_channels_.size(); size_t request_call_num = xpu_channels_.size();
::paddle::distributed::DownpourBrpcClosure* closure = paddle::distributed::DownpourBrpcClosure* closure =
new ::paddle::distributed::DownpourBrpcClosure( new paddle::distributed::DownpourBrpcClosure(
request_call_num, [request_call_num, cmd_id](void* done) { request_call_num, [request_call_num, cmd_id](void* done) {
int ret = 0; int ret = 0;
auto* closure = (::paddle::distributed::DownpourBrpcClosure*)done; auto* closure = (paddle::distributed::DownpourBrpcClosure*)done;
for (size_t i = 0; i < request_call_num; ++i) { for (size_t i = 0; i < request_call_num; ++i) {
if (closure->check_response(i, cmd_id) != 0) { if (closure->check_response(i, cmd_id) != 0) {
ret = -1; ret = -1;
...@@ -204,7 +204,7 @@ std::future<int32_t> HeterClient::SendCmd( ...@@ -204,7 +204,7 @@ std::future<int32_t> HeterClient::SendCmd(
for (const auto& param : params) { for (const auto& param : params) {
closure->request(i)->add_params(param); closure->request(i)->add_params(param);
} }
::paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get()); paddle::distributed::PsService_Stub rpc_stub(xpu_channels_[i].get());
closure->cntl(i)->set_timeout_ms( closure->cntl(i)->set_timeout_ms(
FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load FLAGS_pserver_timeout_ms); // cmd msg don't limit timeout for save/load
rpc_stub.service( rpc_stub.service(
...@@ -270,7 +270,7 @@ int HeterClient::Send(const platform::DeviceContext& ctx, ...@@ -270,7 +270,7 @@ int HeterClient::Send(const platform::DeviceContext& ctx,
} }
brpc::Channel* channel = send_switch_channels_[0].get(); brpc::Channel* channel = send_switch_channels_[0].get();
// brpc::Channel* channel = xpu_channels_[0].get(); // brpc::Channel* channel = xpu_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel); paddle::distributed::PsService_Stub stub(channel);
stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure); stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
VLOG(4) << "waiting SendToSwitch response result......"; VLOG(4) << "waiting SendToSwitch response result......";
...@@ -317,7 +317,7 @@ int HeterClient::Send(int group_id, ...@@ -317,7 +317,7 @@ int HeterClient::Send(int group_id,
send_switch_channels_.push_back(xpu_channels_[0]); send_switch_channels_.push_back(xpu_channels_[0]);
} }
brpc::Channel* channel = send_switch_channels_[0].get(); brpc::Channel* channel = send_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel); paddle::distributed::PsService_Stub stub(channel);
stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure); stub.SendToSwitch(&closure->cntl, &request, &closure->ps_response, closure);
fut.wait(); fut.wait();
delete closure; delete closure;
...@@ -362,7 +362,7 @@ int HeterClient::Recv(const platform::DeviceContext& ctx, ...@@ -362,7 +362,7 @@ int HeterClient::Recv(const platform::DeviceContext& ctx,
recv_switch_channels_.push_back(xpu_channels_[1]); recv_switch_channels_.push_back(xpu_channels_[1]);
} }
brpc::Channel* channel = recv_switch_channels_[0].get(); brpc::Channel* channel = recv_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel); paddle::distributed::PsService_Stub stub(channel);
stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure); stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
fut.wait(); fut.wait();
VLOG(4) << "RecvFromSwitch done"; VLOG(4) << "RecvFromSwitch done";
...@@ -412,7 +412,7 @@ int HeterClient::Recv(int group_id, ...@@ -412,7 +412,7 @@ int HeterClient::Recv(int group_id,
recv_switch_channels_.push_back(xpu_channels_[0]); recv_switch_channels_.push_back(xpu_channels_[0]);
} }
brpc::Channel* channel = recv_switch_channels_[0].get(); brpc::Channel* channel = recv_switch_channels_[0].get();
::paddle::distributed::PsService_Stub stub(channel); paddle::distributed::PsService_Stub stub(channel);
stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure); stub.RecvFromSwitch(&closure->cntl, &request, &closure->response, closure);
fut.wait(); fut.wait();
VLOG(4) << "RecvFromSwitch done"; VLOG(4) << "RecvFromSwitch done";
......
...@@ -59,24 +59,20 @@ ...@@ -59,24 +59,20 @@
#else // PADDLE_WITH_GFLAGS #else // PADDLE_WITH_GFLAGS
#define PHI_DECLARE_VARIABLE(type, shorttype, name) \ #define PHI_DECLARE_VARIABLE(type, shorttype, name) \
namespace paddle { \ namespace paddle_flags { \
namespace flags { \
extern PHI_IMPORT_FLAG type FLAGS_##name; \ extern PHI_IMPORT_FLAG type FLAGS_##name; \
} \ } \
} \ using paddle_flags::FLAGS_##name
using paddle::flags::FLAGS_##name
#define PHI_DEFINE_VARIABLE(type, shorttype, name, default_value, description) \ #define PHI_DEFINE_VARIABLE(type, shorttype, name, default_value, description) \
namespace paddle { \ namespace paddle_flags { \
namespace flags { \
static const type FLAGS_##name##_default = default_value; \ static const type FLAGS_##name##_default = default_value; \
PHI_EXPORT_FLAG type FLAGS_##name = default_value; \ PHI_EXPORT_FLAG type FLAGS_##name = default_value; \
/* Register FLAG */ \ /* Register FLAG */ \
static ::paddle::flags::FlagRegisterer flag_##name##_registerer( \ static ::paddle::flags::FlagRegisterer flag_##name##_registerer( \
#name, description, __FILE__, &FLAGS_##name##_default, &FLAGS_##name); \ #name, description, __FILE__, &FLAGS_##name##_default, &FLAGS_##name); \
} \ } \
} \ using paddle_flags::FLAGS_##name
using paddle::flags::FLAGS_##name
#endif #endif
......
...@@ -74,12 +74,10 @@ void PrintAllFlagHelp(bool to_file = false, ...@@ -74,12 +74,10 @@ void PrintAllFlagHelp(bool to_file = false,
// ----------------------------DECLARE FLAGS---------------------------- // ----------------------------DECLARE FLAGS----------------------------
#define PD_DECLARE_VARIABLE(type, name) \ #define PD_DECLARE_VARIABLE(type, name) \
namespace paddle { \ namespace paddle_flags { \
namespace flags { \
extern type FLAGS_##name; \ extern type FLAGS_##name; \
} \ } \
} \ using paddle_flags::FLAGS_##name
using paddle::flags::FLAGS_##name
#define PD_DECLARE_bool(name) PD_DECLARE_VARIABLE(bool, name) #define PD_DECLARE_bool(name) PD_DECLARE_VARIABLE(bool, name)
#define PD_DECLARE_int32(name) PD_DECLARE_VARIABLE(int32_t, name) #define PD_DECLARE_int32(name) PD_DECLARE_VARIABLE(int32_t, name)
...@@ -105,16 +103,14 @@ class FlagRegisterer { ...@@ -105,16 +103,14 @@ class FlagRegisterer {
// ----------------------------DEFINE FLAGS---------------------------- // ----------------------------DEFINE FLAGS----------------------------
#define PD_DEFINE_VARIABLE(type, name, default_value, description) \ #define PD_DEFINE_VARIABLE(type, name, default_value, description) \
namespace paddle { \ namespace paddle_flags { \
namespace flags { \
static const type FLAGS_##name##_default = default_value; \ static const type FLAGS_##name##_default = default_value; \
type FLAGS_##name = default_value; \ type FLAGS_##name = default_value; \
/* Register FLAG */ \ /* Register FLAG */ \
static ::paddle::flags::FlagRegisterer flag_##name##_registerer( \ static ::paddle::flags::FlagRegisterer flag_##name##_registerer( \
#name, description, __FILE__, &FLAGS_##name##_default, &FLAGS_##name); \ #name, description, __FILE__, &FLAGS_##name##_default, &FLAGS_##name); \
} \ } \
} \ using paddle_flags::FLAGS_##name
using paddle::flags::FLAGS_##name
#define PD_DEFINE_bool(name, val, txt) PD_DEFINE_VARIABLE(bool, name, val, txt) #define PD_DEFINE_bool(name, val, txt) PD_DEFINE_VARIABLE(bool, name, val, txt)
#define PD_DEFINE_int32(name, val, txt) \ #define PD_DEFINE_int32(name, val, txt) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册