未验证 提交 5f89ce7f 编写于 作者: 乔龙飞 Qiao Longfei 提交者: GitHub

Merge pull request #15536 from jacquesqiao/fix-prefetch-one-parameter

Fix prefetch one parameter
...@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
// Async // Async
if (!sync_mode_) { if (!sync_mode_) {
VLOG(3) << "async process var: " << varname; VLOG(3) << "async process var: " << varname;
if (varname == BATCH_BARRIER_MESSAGE) {
PADDLE_THROW(
"async mode should not recv BATCH_BARRIER_MESSAGE or "
"COMPLETE_MESSAGE");
}
try { try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(), executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope); scope);
......
...@@ -39,27 +39,33 @@ void RPCServer::SavePort() const { ...@@ -39,27 +39,33 @@ void RPCServer::SavePort() const {
port_file.open(file_path); port_file.open(file_path);
port_file << selected_port_; port_file << selected_port_;
port_file.close(); port_file.close();
VLOG(4) << "selected port written to " << file_path; VLOG(3) << "selected port written to " << file_path;
} }
void RPCServer::WaitBarrier(const std::string& rpc_name) { void RPCServer::WaitBarrier(const std::string& rpc_name) {
VLOG(3) << "WaitBarrier in: " << rpc_name;
std::unique_lock<std::mutex> lock(this->mutex_); std::unique_lock<std::mutex> lock(this->mutex_);
barrier_cond_.wait(lock, [this, &rpc_name] { barrier_cond_.wait(lock, [this, &rpc_name] {
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) || return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
exit_flag_.load()); exit_flag_.load());
}); });
VLOG(3) << "batch_barrier_: " << rpc_name << " " VLOG(3) << "WaitBarrier out: " << rpc_name
<< barrier_counter_[rpc_name]; << " counter: " << barrier_counter_[rpc_name];
} }
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) { void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name; VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
// barrier msg should make sure that it's in the right cond(send|recv)
WaitCond(rpc_name);
int b = 0; int b = 0;
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
b = ++barrier_counter_[rpc_name]; b = ++barrier_counter_[rpc_name];
VLOG(3) << rpc_name << " barrier_counter: " << b;
if (b >= client_num_) { if (b >= client_num_) {
lock.unlock(); lock.unlock();
VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
<< rpc_name;
barrier_cond_.notify_all(); barrier_cond_.notify_all();
lock.lock(); lock.lock();
} }
...@@ -71,7 +77,7 @@ void RPCServer::Complete() { ...@@ -71,7 +77,7 @@ void RPCServer::Complete() {
client_num_--; client_num_--;
need_reset_all_vars_ = true; need_reset_all_vars_ = true;
VLOG(4) << "decrease client_num to: " << client_num_; VLOG(3) << "decrease client_num to: " << client_num_;
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) { if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
barrier_counter_[kRequestGet]--; barrier_counter_[kRequestGet]--;
} }
...@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name, ...@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name,
static int cond = -1; static int cond = -1;
rpc_cond_map_[rpc_name] = ++cond; rpc_cond_map_[rpc_name] = ++cond;
VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler VLOG(3) << "RegisterRPC rpc_name: " << rpc_name << ", handler: " << handler
<< ", cond:" << rpc_cond_map_[rpc_name]; << ", cond: " << rpc_cond_map_[rpc_name];
} }
void RPCServer::SetCond(const std::string& rpc_name) { void RPCServer::SetCond(const std::string& rpc_name) {
...@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) { ...@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
} }
void RPCServer::WaitCond(const std::string& rpc_name) { void RPCServer::WaitCond(const std::string& rpc_name) {
VLOG(4) << "RPCServer WaitCond " << rpc_name; VLOG(3) << "RPCServer WaitCond in " << rpc_name;
int cond = 0; int cond = 0;
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
...@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) { ...@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait( rpc_cond_.wait(
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); }); lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
VLOG(3) << "RPCServer WaitCond out " << rpc_name;
} }
void RPCServer::RegisterVar(const std::string& var_name, void RPCServer::RegisterVar(const std::string& var_name,
...@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name, ...@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name,
} }
rpc_cond_.notify_all(); rpc_cond_.notify_all();
VLOG(4) << "RegisterVar context:" << h.String(); VLOG(3) << "RegisterVar context:" << h.String();
} }
void RPCServer::IncreaseVarBarrier(const std::string& var_name) { void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
...@@ -167,11 +174,11 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) { ...@@ -167,11 +174,11 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
barrier_cond_.notify_all(); barrier_cond_.notify_all();
} }
VLOG(4) << "IncreaseVarBarrier context:" << h.String(); VLOG(3) << "IncreaseVarBarrier context:" << h.String();
} }
void RPCServer::WaitVarBarrier(const std::string& var_name) { void RPCServer::WaitVarBarrier(const std::string& var_name) {
VLOG(4) << "WaitBarrier var_name:" << var_name; VLOG(3) << "WaitVarBarrier var_name:" << var_name;
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
barrier_cond_.wait(lock, [&]() { barrier_cond_.wait(lock, [&]() {
...@@ -179,11 +186,11 @@ void RPCServer::WaitVarBarrier(const std::string& var_name) { ...@@ -179,11 +186,11 @@ void RPCServer::WaitVarBarrier(const std::string& var_name) {
exit_flag_.load()); exit_flag_.load());
}); });
VLOG(4) << "WaitBarrier context: " << var_map_[var_name].String(); VLOG(3) << "WaitVarBarrier context: " << var_map_[var_name].String();
} }
void RPCServer::SetVarCond(const std::string& var_name) { void RPCServer::SetVarCond(const std::string& var_name) {
VLOG(4) << "SetVarCond var_name:" << var_name; VLOG(3) << "SetVarCond var_name:" << var_name;
{ {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
if (var_map_.find(var_name) != var_map_.end()) { if (var_map_.find(var_name) != var_map_.end()) {
...@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) { ...@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) {
} }
void RPCServer::WaitVarCond(const std::string& var_name) { void RPCServer::WaitVarCond(const std::string& var_name) {
VLOG(4) << "WaitVarCond var_name:" << var_name; VLOG(3) << "WaitVarCond var_name:" << var_name;
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
rpc_cond_.wait(lock, [=] { rpc_cond_.wait(lock, [=] {
return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load()); return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load());
}); });
VLOG(4) << "WaitVarCond var_name:" << var_name << " end"; VLOG(3) << "WaitVarCond var_name:" << var_name << " end";
} }
MonomerHandle RPCServer::GetMonomer(const std::string& var_name) { MonomerHandle RPCServer::GetMonomer(const std::string& var_name) {
......
...@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop( ...@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop(
while (true) { while (true) {
// Get from multiple trainers, we don't care about the order in which // Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient. // the gradients arrives, just add suffix 0~n and merge the gradient.
VLOG(3) << "wait all clients to send gradient";
rpc_service_->SetCond(distributed::kRequestSend); rpc_service_->SetCond(distributed::kRequestSend);
VLOG(3) << "wait all clients to send send_barrier";
rpc_service_->WaitBarrier(distributed::kRequestSend); rpc_service_->WaitBarrier(distributed::kRequestSend);
if (rpc_service_->IsExit()) { if (rpc_service_->IsExit()) {
...@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop( ...@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop(
} }
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program, ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
recv_scope); recv_scope);
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)"; VLOG(3) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
VLOG(3) << "ResetReceivedVars";
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars()); ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
VLOG(3) << "wait all clients to get parameters back";
rpc_service_->SetCond(distributed::kRequestGet); rpc_service_->SetCond(distributed::kRequestGet);
VLOG(3) << "wait all clients to send fetch_barrier";
rpc_service_->WaitBarrier(distributed::kRequestGet); rpc_service_->WaitBarrier(distributed::kRequestGet);
VLOG(3) << "ResetBarrierCounter";
rpc_service_->ResetBarrierCounter(); rpc_service_->ResetBarrierCounter();
} // while(true) } // while(true)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册