From e2540c17f3788da039cd5eb3bf173ee2b42ba366 Mon Sep 17 00:00:00 2001 From: ziyoujiyi <73728031+ziyoujiyi@users.noreply.github.com> Date: Tue, 10 May 2022 15:16:45 +0800 Subject: [PATCH] fix switch client multithread bug (#42600) * back fl * delete ssl cert * . * make warning * . * unittest paral degree * solve unittest * heter & multi cloud commm ready * . * . * arm_brpc compile * . * . * . * . * . * . * . * . * . * . * . * . * . * . * only output is ok * base is ok * . * . * . * . * . * . * . * . * add switch server bin * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * adapt brpc ssl * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * . * fix heter_server & heter_client * . * . * int->int64_t * . * safe map in multithread * fix heter unitest * . * fix code_style * . * fix bug * . --- .../distributed/ps/service/heter_client.h | 21 ++++++++----------- .../distributed/ps/service/heter_server.cc | 7 +++++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/distributed/ps/service/heter_client.h b/paddle/fluid/distributed/ps/service/heter_client.h index 36bafc94370..efaa48470a8 100644 --- a/paddle/fluid/distributed/ps/service/heter_client.h +++ b/paddle/fluid/distributed/ps/service/heter_client.h @@ -171,19 +171,16 @@ class HeterClient { // switch client singleton static std::shared_ptr GetSwitchInstance( const std::vector& peer_endpoints, int32_t peer_role) { + std::unique_lock lock(mtx_); + if (peer_endpoints.empty()) { + VLOG(4) << "init switch client failed, null peer_endpoints"; + } + VLOG(4) << "peer role is: " << peer_role + << ", addr is: " << peer_endpoints[0]; if (switch_s_instance_ == nullptr) { - std::unique_lock lock(mtx_); - if (peer_endpoints.empty()) { - VLOG(4) << "init switch client failed, null peer_endpoints"; - } - VLOG(4) << "peer role is: " << peer_role - << ", addr is: " << peer_endpoints[0]; - if (switch_s_instance_ == nullptr) { - switch_s_instance_.reset(new HeterClient()); - switch_s_instance_->SetPeerSwitchList(peer_endpoints); - switch_s_instance_->InitClientChannels(false, peer_endpoints, - peer_role); - } + switch_s_instance_.reset(new HeterClient()); + switch_s_instance_->SetPeerSwitchList(peer_endpoints); + switch_s_instance_->InitClientChannels(false, peer_endpoints, peer_role); } return switch_s_instance_; } diff --git a/paddle/fluid/distributed/ps/service/heter_server.cc b/paddle/fluid/distributed/ps/service/heter_server.cc index 0753a6799c1..fd38a030ff3 100755 --- a/paddle/fluid/distributed/ps/service/heter_server.cc +++ b/paddle/fluid/distributed/ps/service/heter_server.cc @@ -125,6 +125,9 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard( brpc::Controller* cntl) { VLOG(4) << "entering SaveInSwitchWithShard"; int32_t group_id = request->group_id(); + if (group_id >= FLAGS_heter_world_size) { + LOG(ERROR) << "group id exceed maxmium"; + } auto& local_shard = _local_shards[group_id]; auto& request_io_buffer = cntl->request_attachment(); butil::IOBufBytesIterator io_buffer_itr(request_io_buffer); @@ -132,11 +135,11 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard( const auto& var_name = request->send_var_names(idx); const auto& var_size = request->vars_len(idx); WaitForVarsConsumed(group_id, var_name); + std::unique_lock lk(scope_mutex_); auto& value = local_shard[var_name]; value.resize(var_size); io_buffer_itr.copy_and_forward(reinterpret_cast(value.data()), var_size); - std::unique_lock lk(scope_mutex_); vars_ready_flag[group_id][var_name] = 1; VLOG(4) << "saved var_name: " << var_name << "is saved ready!"; } @@ -162,11 +165,11 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard( VLOG(4) << "req var name: " << req_var_name; response->add_send_var_names(req_var_name); WaitForVarsProduced(group_id, req_var_name); + std::unique_lock lk(scope_mutex_); auto itr = local_shard.find(req_var_name); auto& value = itr.value(); response_io_buffer.append(value.data(), value.size()); value.resize(0); // 清空内存 - std::unique_lock lk(scope_mutex_); vars_ready_flag[group_id][req_var_name] = 0; VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!"; } -- GitLab