未验证 提交 66f1e82f 编写于 作者: Z ziyoujiyi 提交者: GitHub

safe map in heter server (#42276)

* back fl

* delete ssl cert

* .

* make warning

* .

* unittest paral degree

* solve unittest

* heter & multi cloud commm ready

* .

* .

* arm_brpc compile

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* only output is ok

* base is ok

* .

* .

* .

* .

* .

* .

* .

* .

* add switch server bin

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* adapt brpc ssl

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* fix heter_server & heter_client

* .

* .

* int->int64_t

* .

* safe map in multithread

* fix heter unitest

* .

* fix code_style

* .
上级 5063546a
...@@ -23,6 +23,8 @@ DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s"); ...@@ -23,6 +23,8 @@ DEFINE_int32(switch_send_recv_timeout_s, 600, "switch_send_recv_timeout_s");
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr; std::shared_ptr<HeterClient> HeterClient::s_instance_ = nullptr;
std::mutex HeterClient::mtx_;
std::shared_ptr<HeterClient> HeterClient::switch_s_instance_ = nullptr;
int GetMicroId(const platform::DeviceContext& ctx, int GetMicroId(const platform::DeviceContext& ctx,
const framework::Scope* scope) { const framework::Scope* scope) {
......
...@@ -169,16 +169,22 @@ class HeterClient { ...@@ -169,16 +169,22 @@ class HeterClient {
} }
// switch client singleton // switch client singleton
static HeterClient& GetSwitchInstance( static std::shared_ptr<HeterClient> GetSwitchInstance(
const std::vector<std::string>& peer_endpoints, int32_t peer_role) { const std::vector<std::string>& peer_endpoints, int32_t peer_role) {
static HeterClient switch_s_instance_; if (switch_s_instance_ == nullptr) {
std::unique_lock<std::mutex> lock(mtx_);
if (peer_endpoints.empty()) { if (peer_endpoints.empty()) {
VLOG(4) << "init switch client failed, null peer_endpoints"; VLOG(4) << "init switch client failed, null peer_endpoints";
} }
VLOG(4) << "peer role is: " << peer_role VLOG(4) << "peer role is: " << peer_role
<< ", addr is: " << peer_endpoints[0]; << ", addr is: " << peer_endpoints[0];
switch_s_instance_.SetPeerSwitchList(peer_endpoints); if (switch_s_instance_ == nullptr) {
switch_s_instance_.InitClientChannels(false, peer_endpoints, peer_role); switch_s_instance_.reset(new HeterClient());
switch_s_instance_->SetPeerSwitchList(peer_endpoints);
switch_s_instance_->InitClientChannels(false, peer_endpoints,
peer_role);
}
}
return switch_s_instance_; return switch_s_instance_;
} }
...@@ -230,6 +236,8 @@ class HeterClient { ...@@ -230,6 +236,8 @@ class HeterClient {
HeterClient(const HeterClient&); HeterClient(const HeterClient&);
static std::shared_ptr<HeterClient> s_instance_; static std::shared_ptr<HeterClient> s_instance_;
static std::mutex mtx_;
static std::shared_ptr<HeterClient> switch_s_instance_;
std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_; std::vector<std::shared_ptr<brpc::Channel>> xpu_channels_;
std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_; std::vector<std::shared_ptr<brpc::Channel>> previous_xpu_channels_;
......
...@@ -144,31 +144,41 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase { ...@@ -144,31 +144,41 @@ class SendAndRecvVariableHandler final : public ServiceHandlerBase {
brpc::Controller* cntl); brpc::Controller* cntl);
void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) { void WaitForVarsConsumed(int32_t group_id, const std::string& var_name) {
timeline_.Start(); // timeline_.Start();
while (true) { while (true) {
{
std::lock_guard<std::mutex> lock(scope_mutex_);
if (vars_ready_flag[group_id][var_name] == 0) { if (vars_ready_flag[group_id][var_name] == 0) {
break; break;
} }
}
/*
timeline_.Pause(); timeline_.Pause();
if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) { if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
VLOG(0) << "vars not consumed exceed 10 miniutes"; VLOG(0) << "vars not consumed exceed 10 miniutes";
break; break;
} }
*/
} }
return; return;
} }
void WaitForVarsProduced(int32_t group_id, const std::string& var_name) { void WaitForVarsProduced(int32_t group_id, const std::string& var_name) {
timeline_.Start(); // timeline_.Start();
while (true) { while (true) {
{
std::lock_guard<std::mutex> lock(scope_mutex_);
if (vars_ready_flag[group_id][var_name] == 1) { if (vars_ready_flag[group_id][var_name] == 1) {
break; break;
} }
}
/*
timeline_.Pause(); timeline_.Pause();
if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) { if (timeline_.ElapsedSec() > FLAGS_switch_send_recv_timeout_s) {
VLOG(0) << "vars not produced exceed 10 miniutes"; VLOG(0) << "vars not produced exceed 10 miniutes";
break; break;
} }
*/
} }
return; return;
} }
...@@ -379,12 +389,12 @@ class HeterService : public PsService { ...@@ -379,12 +389,12 @@ class HeterService : public PsService {
::google::protobuf::Closure* done) { ::google::protobuf::Closure* done) {
VLOG(4) << "entering SendToSwitch"; VLOG(4) << "entering SendToSwitch";
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
auto& switch_client_ptr_ = std::shared_ptr<HeterClient> switch_client_ptr_ =
HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH); HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_SWITCH);
if (switch_client_ptr_.peer_switch_channels_.empty()) { if (switch_client_ptr_->peer_switch_channels_.empty()) {
LOG(ERROR) << "switch_client_ptr_.peer_switch_channels_ null"; LOG(ERROR) << "switch_client_ptr_->peer_switch_channels_ null";
} }
brpc::Channel* channel = switch_client_ptr_.peer_switch_channels_[0].get(); brpc::Channel* channel = switch_client_ptr_->peer_switch_channels_[0].get();
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller); brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
// proxy: 定义新的 OnHeterRpcDone 对象(或者在类 OnHeterRpcDone 中 reset) // proxy: 定义新的 OnHeterRpcDone 对象(或者在类 OnHeterRpcDone 中 reset)
OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) { OnHeterRpcDone* closure2 = new OnHeterRpcDone([](void* done) {
...@@ -414,6 +424,7 @@ class HeterService : public PsService { ...@@ -414,6 +424,7 @@ class HeterService : public PsService {
std_cntl.response_attachment().movable()); std_cntl.response_attachment().movable());
fut.wait(); fut.wait();
VLOG(4) << "SendToSwitch done"; VLOG(4) << "SendToSwitch done";
delete closure2;
} }
void SendS2S(::google::protobuf::RpcController* controller, void SendS2S(::google::protobuf::RpcController* controller,
...@@ -446,11 +457,11 @@ class HeterService : public PsService { ...@@ -446,11 +457,11 @@ class HeterService : public PsService {
brpc::ClosureGuard done_guard(done); brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(controller); brpc::Controller* cntl = static_cast<brpc::Controller*>(controller);
VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side(); VLOG(4) << "SendToWorker(client addr) =" << cntl->remote_side();
auto& switch_client_ptr_ = std::shared_ptr<distributed::HeterClient> switch_client_ptr_ =
HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER); HeterClient::GetSwitchInstance(peer_endpoints_, PEER_ROLE_IS_WORKER);
VLOG(4) << "in switch client, peer worker 0: " VLOG(4) << "in switch client, peer worker 0: "
<< switch_client_ptr_.peer_worker_list_[0]; << switch_client_ptr_->peer_worker_list_[0];
brpc::Channel* channel = switch_client_ptr_.peer_worker_channels_[0].get(); brpc::Channel* channel = switch_client_ptr_->peer_worker_channels_[0].get();
auto* closure = reinterpret_cast<OnHeterRpcDone*>(done); auto* closure = reinterpret_cast<OnHeterRpcDone*>(done);
PsService_Stub stub(channel); PsService_Stub stub(channel);
......
...@@ -122,6 +122,7 @@ void TestShardSendRecv( ...@@ -122,6 +122,7 @@ void TestShardSendRecv(
void PressTestSendRecv( void PressTestSendRecv(
std::shared_ptr<distributed::HeterClient> heter_client_ptr_) { std::shared_ptr<distributed::HeterClient> heter_client_ptr_) {
// long l = 0, m = 0; // long l = 0, m = 0;
// https://paddlerec.bj.bcebos.com/online_infer/arm_brpc_ubuntu18/send_20_34
std::ifstream file("/send_20_34", std::ios::in | std::ios::binary); std::ifstream file("/send_20_34", std::ios::in | std::ios::binary);
// l = file.tellg(); // l = file.tellg();
// file.seekg(0, std::ios::end); // file.seekg(0, std::ios::end);
...@@ -129,13 +130,13 @@ void PressTestSendRecv( ...@@ -129,13 +130,13 @@ void PressTestSendRecv(
// file.close(); // file.close();
// VLOG(0) << "size of file " << "20_34" << " is " << (m - l) << " bytes.\n"; // VLOG(0) << "size of file " << "20_34" << " is " << (m - l) << " bytes.\n";
int64_t vars_len = 2359296 * sizeof(float); int64_t vars_len = 2359296 * sizeof(float);
int64_t data_size = vars_len * sizeof(float); int64_t data_size = vars_len;
VLOG(0) << "float num: " << data_size; VLOG(0) << "float num: " << data_size;
float* data_ptr = new float[data_size]; float* data_ptr = new float[data_size];
file.read((char*)data_ptr, 9437184); file.read((char*)data_ptr, 9437184);
VLOG(0) << "send data is: " << data_ptr[0] << ", " << data_ptr[1]; VLOG(0) << "send data is: " << data_ptr[0] << ", " << data_ptr[1];
std::vector<std::string> var_names{"34"}; std::vector<std::string> var_names{"34"};
int loopCnt = 600; int loopCnt = 10000;
auto send_async = [&]() -> void { auto send_async = [&]() -> void {
int i = 0; int i = 0;
while (i++ < loopCnt) { while (i++ < loopCnt) {
...@@ -254,8 +255,8 @@ TEST(HETERSENDANDRECV, CPU) { ...@@ -254,8 +255,8 @@ TEST(HETERSENDANDRECV, CPU) {
exe.Prepare(program, 0); // solve undefined symbol: tensor_table.cc exe.Prepare(program, 0); // solve undefined symbol: tensor_table.cc
// TestScopeSendRecv(heter_client_ptr_); // TestScopeSendRecv(heter_client_ptr_);
TestShardSendRecv(heter_client_ptr_); // TestShardSendRecv(heter_client_ptr_);
// PressTestSendRecv(heter_client_ptr_); PressTestSendRecv(heter_client_ptr_);
switch_server_ptr_a->Stop(); switch_server_ptr_a->Stop();
LOG(INFO) << "switch server A stopped"; LOG(INFO) << "switch server A stopped";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册