diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index ba2ea0d13e5a171df57ab1ff368d60ebe96571f7..616a89a4132a07d16ac9deae4ddbf7071ab9b1aa 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc /* NOTE(review): this is a unified diff flattened onto three lines; template arguments look stripped (for example "std::future future" and "std::make_shared()") and one parameter name is garbled - confirm against the original patch before applying. Summary of the visible hunks: the inline AsyncExecuteBlock helper is removed, an AsyncUpdateThread worker loop is added, and RunAsyncLoop drops its recv_scope and prefetch_block parameters and pushes received messages onto per-gradient queues. */ @@ -45,20 +45,6 @@ static void split(const std::string &str, char sep, } } /* Removed helper: wrapped one RunPreparedContext call in framework::Async, logged any std::exception, and waited on the returned future right away. */ -static void AsyncExecuteBlock(framework::Executor *executor, - framework::ExecutorPrepareContext *prepared, - framework::Scope *scope) { - std::future future = framework::Async([&executor, &prepared, &scope]() { - try { - executor->RunPreparedContext(prepared, scope, false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - }); - // TODO(qiao) maybe we can remove this - future.wait(); -} - static void ParallelExecuteBlocks( const std::vector ¶llel_blkids, framework::Executor *executor, const std::vector> @@ -201,14 +187,35 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, } // while(true) } /* Added worker loop: while exit_flag is false it pops one message from queue, logs and throws via PADDLE_THROW when the message variable is null, then runs the prepared block in the message's mutable local scope and logs any std::exception from that run. The PADDLE_THROW sits outside the try block, so it is not handled by the catch below it. NOTE(review): exit_flag is a plain const bool reference that the caller passes into a framework::Async task (presumably another thread), and it is only re-checked after queue->Pop() returns - confirm how workers are expected to observe shutdown. */ +static void AsyncUpdateThread( + const bool &exit_flag, const std::shared_ptr &queue, + framework::Executor *executor, + framework::ExecutorPrepareContext *prepared) { + while (!exit_flag) { + const detail::ReceivedMessage v = queue->Pop(); + auto recv_var_name = v.first; + auto var = v.second->GetVar(); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << recv_var_name; + PADDLE_THROW("Can not find server side var"); + } + try { + executor->RunPreparedContext(prepared, v.second->GetMutableLocalScope(), + false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + } +} + /* RunAsyncLoop signature change: only executor and program remain as parameters. */ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, - framework::ProgramDesc *program, - framework::Scope *recv_scope, - framework::BlockDesc *prefetch_block) const { + framework::ProgramDesc *program) const { VLOG(3) << "RunAsyncLoop in"; // grad name to block id std::unordered_map 
grad_to_block_id; std::unordered_map id_to_grad; /* Added map: one queue object per gradient name, created with std::make_shared in the parsing loop below. */ + std::unordered_map> + grad_to_queue; auto grad_to_block_id_str = Attr>("grad_to_block_id"); @@ -220,6 +227,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0); int block_id = std::stoi(pieces[1]); grad_to_block_id[pieces[0]] = block_id; + grad_to_queue[pieces[0]] = std::make_shared(); id_to_grad[block_id] = pieces[0]; } size_t num_blocks = program->Size(); @@ -240,6 +248,18 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, VLOG(3) << "RunAsyncLoop into while"; bool exit_flag = false; + /* Added start-up step: one framework::Async task running AsyncUpdateThread is launched per entry of grad_to_queue, and the returned futures are stored in fs. NOTE(review): the lambda captures exit_flag, executor, grad_to_queue and grad_to_prepared_block by reference and indexes both maps with operator[] - confirm these lookups are safe alongside the dispatch loop below, and that fs is waited on before the locals go out of scope (no wait is visible in these hunks). */ + VLOG(3) << "start async optimize threads"; + std::vector> fs; + for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) { + std::string grad_name = iter->first; + fs.push_back(framework::Async([grad_name, &exit_flag, &executor, + &grad_to_queue, &grad_to_prepared_block]() { + AsyncUpdateThread(exit_flag, grad_to_queue[grad_name], executor, + grad_to_prepared_block[grad_name].get()); + })); + } + while (!exit_flag) { const detail::ReceivedMessage v = rpc_service_->Get(); auto recv_var_name = v.first; @@ -249,17 +269,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, break; } else { VLOG(3) << "received grad: " << recv_var_name; /* Removed inline path: null check, AsyncExecuteBlock call, and clearing of mutable_rows() for one variable type. */ - auto var = v.second->GetVar(); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << recv_var_name; - PADDLE_THROW("Can not find server side var"); - } - AsyncExecuteBlock(executor, grad_to_prepared_block[recv_var_name].get(), - v.second->GetMutableLocalScope()); - // TODO(qiao): explain why - if (var->IsType()) { - var->GetMutable()->mutable_rows()->clear(); - } /* Added dispatch: the received message is pushed onto the queue registered for its variable name. NOTE(review): operator[] on grad_to_queue presumably default-inserts an empty pointer for a name that was never registered - confirm names are always present. NOTE(review): the removed mutable_rows()->clear() call has no counterpart in the added code visible here. */ + grad_to_queue[recv_var_name]->Push(v); } if (exit_flag) { @@ -308,7 +318,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block); } else { - RunAsyncLoop(&executor, program, &recv_scope, prefetch_block); + 
RunAsyncLoop(&executor, program); } } diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 3cc0f3047733bea94daa310cd39cb0a4f44bef85..5c8fc31c9774a0f2e8233824459b29b42469bd1a 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -47,9 +47,7 @@ class ListenAndServOp : public framework::OperatorBase { framework::BlockDesc* prefetch_block) const; /* Header hunk: the RunAsyncLoop declaration is reduced to the same two parameters as the definition in listen_and_serv_op.cc. */ void RunAsyncLoop(framework::Executor* executor, - framework::ProgramDesc* program, - framework::Scope* recv_scope, - framework::BlockDesc* prefetch_block) const; + framework::ProgramDesc* program) const; void Stop() override;