未验证 提交 bc8e600c 编写于 作者: C Chengmo 提交者: GitHub

Fix rpc not wait in GEO communicator (#20967)

* test=develop,fix rpc not wait in geo
上级 008ed65f
...@@ -937,8 +937,9 @@ void GeoSgdCommunicator::RpcSend(const std::string &origin_var_name, ...@@ -937,8 +937,9 @@ void GeoSgdCommunicator::RpcSend(const std::string &origin_var_name,
auto &cpu_ctx_send = *pool.Get(platform::CPUPlace()); auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
rpc_client->AsyncSendVar(endpoint, cpu_ctx_send, *delta_scope_.get(), auto handle = rpc_client->AsyncSendVar(endpoint, cpu_ctx_send,
splited_var_name); *delta_scope_.get(), splited_var_name);
handle->Wait();
} }
void GeoSgdCommunicator::RpcRecv(const std::string &var_name, void GeoSgdCommunicator::RpcRecv(const std::string &var_name,
...@@ -951,8 +952,10 @@ void GeoSgdCommunicator::RpcRecv(const std::string &var_name, ...@@ -951,8 +952,10 @@ void GeoSgdCommunicator::RpcRecv(const std::string &var_name,
distributed::RPCClient *rpc_client = distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id); distributed::RPCClient::GetInstance<RPCCLIENT_T>(train_id);
pserver_scope_->Var(splited_var_name); pserver_scope_->Var(splited_var_name);
rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv, *pserver_scope_.get(), auto handle = rpc_client->AsyncGetVar(endpoint, cpu_ctx_recv,
splited_var_name, splited_var_name, splited_var_name); *pserver_scope_.get(), splited_var_name,
splited_var_name, splited_var_name);
handle->Wait();
} }
void GeoSgdCommunicator::Recv() {} void GeoSgdCommunicator::Recv() {}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册