提交 37410a0c 编写于 作者: Y Yancey1989

update by comment

上级 029425a5
...@@ -47,19 +47,15 @@ Executor::Executor(const platform::Place& place) : place_(place) {} ...@@ -47,19 +47,15 @@ Executor::Executor(const platform::Place& place) : place_(place) {}
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
void Executor::BeginPass() { void Executor::BeginPass() {
auto client = ::paddle::operators::distributed::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>(); ::paddle::operators::distributed::GRPCClient>()
->SendBeginPass();
client->SendBeginPass();
client->Wait();
} }
void Executor::EndPass() { void Executor::EndPass() {
auto client = ::paddle::operators::distributed::RPCClient::GetInstance< ::paddle::operators::distributed::RPCClient::GetInstance<
::paddle::operators::distributed::GRPCClient>(); ::paddle::operators::distributed::GRPCClient>()
->SendEndPass();
client->SendEndPass();
client->Wait();
} }
#endif #endif
......
...@@ -40,6 +40,7 @@ void GRPCClient::SendBeginPass() { ...@@ -40,6 +40,7 @@ void GRPCClient::SendBeginPass() {
VLOG(3) << "send begin pass to: " << it.first; VLOG(3) << "send begin pass to: " << it.first;
this->AsyncSendBeginPass(it.first); this->AsyncSendBeginPass(it.first);
} }
this->Wait();
} }
void GRPCClient::SendEndPass() { void GRPCClient::SendEndPass() {
...@@ -47,6 +48,7 @@ void GRPCClient::SendEndPass() { ...@@ -47,6 +48,7 @@ void GRPCClient::SendEndPass() {
VLOG(3) << "send end pass to " << it.first; VLOG(3) << "send end pass to " << it.first;
this->AsyncSendEndPass(it.first); this->AsyncSendEndPass(it.first);
} }
this->Wait();
} }
GRPCClient::~GRPCClient() { GRPCClient::~GRPCClient() {
......
...@@ -67,7 +67,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { ...@@ -67,7 +67,7 @@ void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
void RPCServer::BeginPass() { void RPCServer::BeginPass() {
VLOG(4) << "RPCServer begin increase pass barrier"; VLOG(4) << "RPCServer begin increase pass barrier";
{ {
std::unique_lock<std::mutex> locl(mutex_); std::unique_lock<std::mutex> lock(mutex_);
client_num_++; client_num_++;
VLOG(4) << "increase client_num to: " << client_num_; VLOG(4) << "increase client_num to: " << client_num_;
} }
...@@ -77,7 +77,7 @@ void RPCServer::BeginPass() { ...@@ -77,7 +77,7 @@ void RPCServer::BeginPass() {
void RPCServer::EndPass() { void RPCServer::EndPass() {
VLOG(4) << "RPCServer begin increase pass barrier"; VLOG(4) << "RPCServer begin increase pass barrier";
{ {
std::unique_lock<std::mutex> locl(mutex_); std::unique_lock<std::mutex> lock(mutex_);
client_num_--; client_num_--;
VLOG(4) << "decrease client_num to: " << client_num_; VLOG(4) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册