未验证 提交 94cb210b 编写于 作者: C Chengmo 提交者: GitHub

【Cherry-pick】Fix Parameter Server Bug (#30860)

* 【Paddle.Fleet】Fix brpc get hostname (#30703)

* fix Brpc get hostname

* fix int64 bug (#30780)

fix push sparse int64 bug
上级 b5df2dea
...@@ -25,9 +25,10 @@ set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMP ...@@ -25,9 +25,10 @@ set_source_files_properties(client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMP
set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(ps_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(server.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table ${RPC_DEPS}) cc_library(downpour_server SRCS brpc_ps_server.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS})
cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table ${RPC_DEPS}) cc_library(downpour_client SRCS brpc_ps_client.cc DEPS boost eigen3 table brpc_utils ${RPC_DEPS})
cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS}) cc_library(client SRCS ps_client.cc DEPS downpour_client boost ${RPC_DEPS})
cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS}) cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
...@@ -35,6 +36,5 @@ cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS}) ...@@ -35,6 +36,5 @@ cc_library(server SRCS server.cc DEPS downpour_server boost ${RPC_DEPS})
cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS}) cc_library(communicator SRCS communicator.cc DEPS scope client boost table math_function selected_rows_functor ${RPC_DEPS})
cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS}) cc_library(ps_service SRCS service.cc DEPS communicator client server boost ${RPC_DEPS})
cc_library(brpc_utils SRCS brpc_utils.cc DEPS tensor device_context ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) cc_library(heter_server SRCS heter_server.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS}) cc_library(heter_client SRCS heter_client.cc DEPS brpc_utils ${COMMON_DEPS} ${RPC_DEPS})
...@@ -134,8 +134,15 @@ int32_t BrpcPsClient::create_client2client_connection( ...@@ -134,8 +134,15 @@ int32_t BrpcPsClient::create_client2client_connection(
server_ip_port.append(std::to_string(client_list[i].port)); server_ip_port.append(std::to_string(client_list[i].port));
_client_channels[i].reset(new brpc::Channel()); _client_channels[i].reset(new brpc::Channel());
if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) { if (_client_channels[i]->Init(server_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "psclient connect to client:" << server_ip_port VLOG(0) << "BrpcPSClient connect to Client:" << server_ip_port
<< " Failed!"; << " Failed! Try again.";
std::string int_ip_port =
GetIntTypeEndpoint(client_list[i].ip, client_list[i].port);
if (_client_channels[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "BrpcPSClient connect to Client:" << int_ip_port
<< " Failed!";
return -1;
}
} }
os << server_ip_port << ","; os << server_ip_port << ",";
} }
...@@ -168,9 +175,16 @@ int32_t BrpcPsClient::initialize() { ...@@ -168,9 +175,16 @@ int32_t BrpcPsClient::initialize() {
_server_channels[i][j].reset(new brpc::Channel()); _server_channels[i][j].reset(new brpc::Channel());
if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) != if (_server_channels[i][j]->Init(server_ip_port.c_str(), "", &options) !=
0) { 0) {
LOG(ERROR) << "psclient connect to server:" << server_ip_port VLOG(0) << "BrpcPSclient connect to Server:" << server_ip_port
<< " Failed!"; << " Failed! Try again.";
return -1; std::string int_ip_port =
GetIntTypeEndpoint(server_list[i].ip, server_list[i].port);
if (_server_channels[i][j]->Init(int_ip_port.c_str(), "", &options) !=
0) {
LOG(ERROR) << "BrpcPSclient connect to Server:" << int_ip_port
<< " Failed!";
return -1;
}
} }
} }
os << server_ip_port << ","; os << server_ip_port << ",";
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "brpc/channel.h" #include "brpc/channel.h"
#include "brpc/controller.h" #include "brpc/controller.h"
#include "brpc/server.h" #include "brpc/server.h"
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/ps_client.h" #include "paddle/fluid/distributed/service/ps_client.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/distributed/service/brpc_ps_server.h" #include "paddle/fluid/distributed/service/brpc_ps_server.h"
#include <netdb.h>
#include <thread> // NOLINT #include <thread> // NOLINT
#include "Eigen/Dense" #include "Eigen/Dense"
#include "butil/endpoint.h" #include "butil/endpoint.h"
...@@ -65,9 +65,17 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) { ...@@ -65,9 +65,17 @@ uint64_t BrpcPsServer::start(const std::string &ip, uint32_t port) {
options.num_threads = trainers > num_threads ? trainers : num_threads; options.num_threads = trainers > num_threads ? trainers : num_threads;
if (_server.Start(ip_port.c_str(), &options) != 0) { if (_server.Start(ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port=" << ip_port; VLOG(0) << "BrpcPsServer start failed, ip_port= " << ip_port
return 0; << " , Try Again.";
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (_server.Start(int_ip_port.c_str(), &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
return 0;
}
} }
VLOG(0) << "BrpcPsServer::start registe_ps_server"; VLOG(0) << "BrpcPsServer::start registe_ps_server";
_environment->registe_ps_server(ip, port, _rank); _environment->registe_ps_server(ip, port, _rank);
VLOG(0) << "BrpcPsServer::start wait"; VLOG(0) << "BrpcPsServer::start wait";
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/distributed/service/brpc_utils.h"
#include "paddle/fluid/distributed/service/server.h" #include "paddle/fluid/distributed/service/server.h"
namespace paddle { namespace paddle {
...@@ -43,7 +44,6 @@ class BrpcPsServer : public PSServer { ...@@ -43,7 +44,6 @@ class BrpcPsServer : public PSServer {
private: private:
virtual int32_t initialize(); virtual int32_t initialize();
mutable std::mutex mutex_; mutable std::mutex mutex_;
std::condition_variable cv_; std::condition_variable cv_;
bool stoped_ = false; bool stoped_ = false;
......
...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/distributed/service/brpc_utils.h" #include "paddle/fluid/distributed/service/brpc_utils.h"
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -310,5 +313,32 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, ...@@ -310,5 +313,32 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
} }
} }
std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port) {
// There are usually two forms of IP address: ip(int) / ip (hostname)
// If there're some problem with DNS, or ip triggers the bug of Brpc
// We will try to get the IP address of the domain name manually again
std::string ip_port = ip + ":" + std::to_string(port);
struct hostent* hp = NULL;
hp = gethostbyname(ip.c_str());
if (NULL == hp) {
LOG(ERROR) << "Brpc Start failed, ip_port= " << ip_port
<< " , Error infomation: " << hstrerror(h_errno);
}
int i = 0;
char* int_ip = NULL;
while (hp->h_addr_list[i] != NULL) {
int_ip = inet_ntoa(*(struct in_addr*)hp->h_addr_list[i]);
VLOG(0) << "Brpc Get host by name, host:" << ip << " -> ip: " << int_ip;
break;
}
std::string str_ip = int_ip;
std::string int_ip_port = str_ip + ":" + std::to_string(port);
return int_ip_port;
}
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -14,10 +14,10 @@ limitations under the License. */ ...@@ -14,10 +14,10 @@ limitations under the License. */
#pragma once #pragma once
#include <netdb.h>
#include <iostream> #include <iostream>
#include <string> #include <string>
#include <vector> #include <vector>
#include "brpc/channel.h" #include "brpc/channel.h"
#include "paddle/fluid/distributed/service/sendrecv.pb.h" #include "paddle/fluid/distributed/service/sendrecv.pb.h"
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
...@@ -82,5 +82,7 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg, ...@@ -82,5 +82,7 @@ void DeserializeSelectedRows(framework::Variable* var, const VarMsg& msg,
butil::IOBufBytesIterator& iobuf, butil::IOBufBytesIterator& iobuf,
const platform::DeviceContext& ctx); const platform::DeviceContext& ctx);
std::string GetIntTypeEndpoint(const std::string& ip, const uint32_t& port);
} // namespace distributed } // namespace distributed
} // namespace paddle } // namespace paddle
...@@ -290,7 +290,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id, ...@@ -290,7 +290,7 @@ void Communicator::RpcSendSparse(const std::string &var_name, int table_id,
auto dim = tensor->value().dims()[1]; auto dim = tensor->value().dims()[1];
std::transform(tensor->rows().begin(), tensor->rows().end(), std::transform(tensor->rows().begin(), tensor->rows().end(),
std::back_inserter(sparse_push_keys), std::back_inserter(sparse_push_keys),
[&](int id) { return static_cast<uint64_t>(id); }); [&](int64_t id) { return static_cast<uint64_t>(id); });
for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) { for (auto i = 0; i < static_cast<int>(sparse_push_keys.size()); ++i) {
push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim); push_g_vec.push_back(tensor->mutable_value()->data<float>() + i * dim);
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/split.h"
DECLARE_int32(rpc_deadline); DECLARE_int32(rpc_deadline);
DECLARE_int32(pserver_timeout_ms); DECLARE_int32(pserver_timeout_ms);
...@@ -96,7 +97,14 @@ void HeterClient::CreateClient2XpuConnection() { ...@@ -96,7 +97,14 @@ void HeterClient::CreateClient2XpuConnection() {
for (size_t i = 0; i < xpu_list_.size(); ++i) { for (size_t i = 0; i < xpu_list_.size(); ++i) {
xpu_channels_[i].reset(new brpc::Channel()); xpu_channels_[i].reset(new brpc::Channel());
if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) { if (xpu_channels_[i]->Init(xpu_list_[i].c_str(), "", &options) != 0) {
VLOG(0) << "HeterServer channel init fail"; VLOG(0) << "HeterClient channel init fail. Try Again";
auto ip_port = paddle::string::Split(xpu_list_[i], ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (xpu_channels_[i]->Init(int_ip_port.c_str(), "", &options) != 0) {
LOG(ERROR) << "BrpcPsServer start failed, ip_port= " << int_ip_port;
}
} }
} }
} }
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/timer.h" #include "paddle/fluid/platform/timer.h"
#include "paddle/fluid/string/split.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -34,7 +35,14 @@ void HeterServer::StartHeterService() { ...@@ -34,7 +35,14 @@ void HeterServer::StartHeterService() {
server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE); server_.AddService(&service_, brpc::SERVER_DOESNT_OWN_SERVICE);
brpc::ServerOptions options; brpc::ServerOptions options;
if (server_.Start(endpoint_.c_str(), &options) != 0) { if (server_.Start(endpoint_.c_str(), &options) != 0) {
VLOG(0) << "heter server start fail"; VLOG(0) << "HeterServer start fail. Try again.";
auto ip_port = paddle::string::Split(endpoint_, ':');
std::string ip = ip_port[0];
int port = std::stoi(ip_port[1]);
std::string int_ip_port = GetIntTypeEndpoint(ip, port);
if (server_.Start(endpoint_.c_str(), &options) != 0) {
LOG(ERROR) << "HeterServer start failed, ip_port= " << int_ip_port;
}
} else { } else {
VLOG(0) << "heter server start success! listen on " << endpoint_; VLOG(0) << "heter server start success! listen on " << endpoint_;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册