未验证 提交 e2540c17 编写于 作者: Z ziyoujiyi 提交者: GitHub

fix switch client multithread bug (#42600)

* back fl

* delete ssl cert

* .

* make warning

* .

* unittest parallel degree

* solve unittest

* heter & multi cloud comm ready

* .

* .

* arm_brpc compile

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* only output is ok

* base is ok

* .

* .

* .

* .

* .

* .

* .

* .

* add switch server bin

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* adapt brpc ssl

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* .

* fix heter_server & heter_client

* .

* .

* int->int64_t

* .

* safe map in multithread

* fix heter unittest

* .

* fix code_style

* .

* fix bug

* .
上级 d47690b2
...@@ -171,19 +171,16 @@ class HeterClient { ...@@ -171,19 +171,16 @@ class HeterClient {
// switch client singleton // switch client singleton
static std::shared_ptr<HeterClient> GetSwitchInstance( static std::shared_ptr<HeterClient> GetSwitchInstance(
const std::vector<std::string>& peer_endpoints, int32_t peer_role) { const std::vector<std::string>& peer_endpoints, int32_t peer_role) {
std::unique_lock<std::mutex> lock(mtx_);
if (peer_endpoints.empty()) {
VLOG(4) << "init switch client failed, null peer_endpoints";
}
VLOG(4) << "peer role is: " << peer_role
<< ", addr is: " << peer_endpoints[0];
if (switch_s_instance_ == nullptr) { if (switch_s_instance_ == nullptr) {
std::unique_lock<std::mutex> lock(mtx_); switch_s_instance_.reset(new HeterClient());
if (peer_endpoints.empty()) { switch_s_instance_->SetPeerSwitchList(peer_endpoints);
VLOG(4) << "init switch client failed, null peer_endpoints"; switch_s_instance_->InitClientChannels(false, peer_endpoints, peer_role);
}
VLOG(4) << "peer role is: " << peer_role
<< ", addr is: " << peer_endpoints[0];
if (switch_s_instance_ == nullptr) {
switch_s_instance_.reset(new HeterClient());
switch_s_instance_->SetPeerSwitchList(peer_endpoints);
switch_s_instance_->InitClientChannels(false, peer_endpoints,
peer_role);
}
} }
return switch_s_instance_; return switch_s_instance_;
} }
......
...@@ -125,6 +125,9 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard( ...@@ -125,6 +125,9 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard(
brpc::Controller* cntl) { brpc::Controller* cntl) {
VLOG(4) << "entering SaveInSwitchWithShard"; VLOG(4) << "entering SaveInSwitchWithShard";
int32_t group_id = request->group_id(); int32_t group_id = request->group_id();
if (group_id >= FLAGS_heter_world_size) {
LOG(ERROR) << "group id exceed maxmium";
}
auto& local_shard = _local_shards[group_id]; auto& local_shard = _local_shards[group_id];
auto& request_io_buffer = cntl->request_attachment(); auto& request_io_buffer = cntl->request_attachment();
butil::IOBufBytesIterator io_buffer_itr(request_io_buffer); butil::IOBufBytesIterator io_buffer_itr(request_io_buffer);
...@@ -132,11 +135,11 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard( ...@@ -132,11 +135,11 @@ int SendAndRecvVariableHandler::SaveInSwitchWithShard(
const auto& var_name = request->send_var_names(idx); const auto& var_name = request->send_var_names(idx);
const auto& var_size = request->vars_len(idx); const auto& var_size = request->vars_len(idx);
WaitForVarsConsumed(group_id, var_name); WaitForVarsConsumed(group_id, var_name);
std::unique_lock<std::mutex> lk(scope_mutex_);
auto& value = local_shard[var_name]; auto& value = local_shard[var_name];
value.resize(var_size); value.resize(var_size);
io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(value.data()), io_buffer_itr.copy_and_forward(reinterpret_cast<void*>(value.data()),
var_size); var_size);
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_ready_flag[group_id][var_name] = 1; vars_ready_flag[group_id][var_name] = 1;
VLOG(4) << "saved var_name: " << var_name << "is saved ready!"; VLOG(4) << "saved var_name: " << var_name << "is saved ready!";
} }
...@@ -162,11 +165,11 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard( ...@@ -162,11 +165,11 @@ int SendAndRecvVariableHandler::QueryInSwitchWithShard(
VLOG(4) << "req var name: " << req_var_name; VLOG(4) << "req var name: " << req_var_name;
response->add_send_var_names(req_var_name); response->add_send_var_names(req_var_name);
WaitForVarsProduced(group_id, req_var_name); WaitForVarsProduced(group_id, req_var_name);
std::unique_lock<std::mutex> lk(scope_mutex_);
auto itr = local_shard.find(req_var_name); auto itr = local_shard.find(req_var_name);
auto& value = itr.value(); auto& value = itr.value();
response_io_buffer.append(value.data(), value.size()); response_io_buffer.append(value.data(), value.size());
value.resize(0); // 清空内存 value.resize(0); // 清空内存
std::unique_lock<std::mutex> lk(scope_mutex_);
vars_ready_flag[group_id][req_var_name] = 0; vars_ready_flag[group_id][req_var_name] = 0;
VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!"; VLOG(4) << "query var_name: " << req_var_name << "is consumed ready!";
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册