diff --git a/paddle/fluid/distributed/ps/service/heter_client.h b/paddle/fluid/distributed/ps/service/heter_client.h index 36bafc943701f61715953837c22c26eb3b6d0114..efaa48470a8bd7f163c4f098749a2eb8d9d16cc4 100644 --- a/paddle/fluid/distributed/ps/service/heter_client.h +++ b/paddle/fluid/distributed/ps/service/heter_client.h @@ -171,19 +171,16 @@ class HeterClient { // switch client singleton static std::shared_ptr GetSwitchInstance( const std::vector& peer_endpoints, int32_t peer_role) { + std::unique_lock lock(mtx_); + if (peer_endpoints.empty()) { + VLOG(4) << "init switch client failed, null peer_endpoints"; + } + VLOG(4) << "peer role is: " << peer_role + << ", addr is: " << peer_endpoints[0]; if (switch_s_instance_ == nullptr) { - std::unique_lock lock(mtx_); - if (peer_endpoints.empty()) { - VLOG(4) << "init switch client failed, null peer_endpoints"; - } - VLOG(4) << "peer role is: " << peer_role - << ", addr is: " << peer_endpoints[0]; - if (switch_s_instance_ == nullptr) { - switch_s_instance_.reset(new HeterClient()); - switch_s_instance_->SetPeerSwitchList(peer_endpoints); - switch_s_instance_->InitClientChannels(false, peer_endpoints, - peer_role); - } + switch_s_instance_.reset(new HeterClient()); + switch_s_instance_->SetPeerSwitchList(peer_endpoints); + switch_s_instance_->InitClientChannels(false, peer_endpoints, peer_role); } return switch_s_instance_; } diff --git a/paddle/fluid/distributed/ps/service/heter_server.cc b/paddle/fluid/distributed/ps/service/heter_server.cc index 0753a6799c1be5c92eea71d131c0f4771ab6d204..fd38a030ff366abc6bd8534af13501355b034968 100755 --- a/paddle/fluid/distributed/ps/service/heter_server.cc +++ b/paddle/fluid/distributed/ps/service/heter_server.cc @@ -125,6 +125,9 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard( brpc::Controller* cntl) { VLOG(4) << "entering SaveInSwitchWithShard"; int32_t group_id = request->group_id(); + if (group_id >= FLAGS_heter_world_size) { + LOG(ERROR) << "group id exceed maxmium"; + } auto& local_shard = _local_shards[group_id]; auto& request_io_buffer = cntl->request_attachment(); butil::IOBufBytesIterator io_buffer_itr(request_io_buffer); @@ -132,11 +135,11 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard( const auto& var_name = request->send_var_names(idx); const auto& var_size = request->vars_len(idx); WaitForVarsConsumed(group_id, var_name); + std::unique_lock lk(scope_mutex_); auto& value = local_shard[var_name]; value.resize(var_size); io_buffer_itr.copy_and_forward(reinterpret_cast(value.data()), var_size); - std::unique_lock lk(scope_mutex_); vars_ready_flag[group_id][var_name] = 1; VLOG(4) << "saved var_name: " << var_name << "is saved ready!"; } @@ -162,11 +165,11 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard( VLOG(4) << "req var name: " << req_var_name; response->add_send_var_names(req_var_name); WaitForVarsProduced(group_id, req_var_name); + std::unique_lock lk(scope_mutex_); auto itr = local_shard.find(req_var_name); auto& value = itr.value(); response_io_buffer.append(value.data(), value.size()); value.resize(0); // 清空内存 - std::unique_lock lk(scope_mutex_); vars_ready_flag[group_id][req_var_name] = 0; VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!"; }