diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 95f4738b4ff50852d9591719133ca650533bf848..7ca694886e9209a49e214352f5babc473a1f275a 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -82,7 +82,9 @@ class RequestSend final : public RequestBase { virtual std::string GetReqName() { return request_->Varname(); } virtual void Process() { - queue_->Push(std::make_pair(request_->Varname(), request_)); + std::string var_name = GetReqName(); + VLOG(3) << "RequestSend " << var_name; + queue_->Push(std::make_pair(var_name, request_)); sendrecv::VoidMessage reply; responder_.Finish(reply, ::grpc::Status::OK, this); @@ -106,7 +108,7 @@ class RequestGet final : public RequestBase { responder_(&ctx_), scope_(scope), queue_(queue) { - int method_id = static_cast(detail::GrpcMethod::kGetVariable); + auto method_id = static_cast(detail::GrpcMethod::kGetVariable); service_->RequestAsyncUnary(method_id, &ctx_, &request_, &responder_, cq_, cq_, this); } @@ -118,6 +120,7 @@ class RequestGet final : public RequestBase { virtual void Process() { // proc request. std::string var_name = request_.varname(); + VLOG(3) << "RequestGet " << var_name; auto* var = scope_->FindVar(var_name); ::grpc::ByteBuffer reply; @@ -176,7 +179,7 @@ class RequestPrefetch final : public RequestBase { ::grpc::ByteBuffer reply; std::string var_name = request_->OutVarname(); - VLOG(3) << "prefetch var " << var_name; + VLOG(3) << "RequestPrefetch " << var_name; auto var_desc = program_->Block(0).FindVar(var_name); framework::Scope* local_scope = &scope_->NewScope(); auto* var = local_scope->FindVar(var_name); @@ -307,18 +310,20 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, bool ok = false; while (true) { - VLOG(3) << "HandleRequest for " << cq_name << " while in"; + VLOG(3) << "HandleRequest for " << cq_name << " wait Next"; if (!cq->Next(&tag, &ok)) { LOG(INFO) << cq_name << " CompletionQueue shutdown!"; break; } - VLOG(3) << "HandleRequest for " << cq_name << " while after Next"; + VLOG(3) << "HandleRequest for " << cq_name << " get Next"; PADDLE_ENFORCE(tag); + if (sync_mode_) { // FIXME(typhoonzero): de-couple the barriers with recv_op if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); + VLOG(3) << "HandleRequest for " << cq_name << " after WaitCond"; } RequestBase* base = reinterpret_cast(tag); @@ -336,9 +341,9 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, switch (base->Status()) { case PROCESS: { - VLOG(4) << cq_name << " PROCESS status:" << base->Status(); TryToRegisterNewOne(); base->Process(); + VLOG(4) << cq_name << " PROCESS status:" << base->Status(); break; } case FINISH: { diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 57cff680ab89f2df7e71af4056ee06cdf330bbab..f22f8b261030c0c536e2118351ec2aa1a9be6cd0 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -45,20 +45,6 @@ static void split(const std::string &str, char sep, } } -static void AsyncExecuteBlock(framework::Executor *executor, - framework::ExecutorPrepareContext *prepared, - framework::Scope *scope) { - std::future future = framework::Async([&executor, &prepared, &scope]() { - try { - executor->RunPreparedContext(prepared, scope, false, false); - } catch (std::exception &e) { - LOG(ERROR) << "run sub program error " << e.what(); - } - }); - // TODO(qiao) maybe we can remove this - future.wait(); -} - static void ParallelExecuteBlocks( const std::vector ¶llel_blkids, framework::Executor *executor, const std::vector> @@ -201,14 +187,40 @@ void ListenAndServOp::RunSyncLoop(framework::Executor *executor, } // while(true) } +static void AsyncUpdateThread( + const std::string &var_name, const bool &exit_flag, + const std::shared_ptr &queue, + framework::Executor *executor, + framework::ExecutorPrepareContext *prepared) { + VLOG(3) << "update thread for " << var_name << " started"; + while (!exit_flag) { + const detail::ReceivedMessage v = queue->Pop(); + auto recv_var_name = v.first; + auto var = v.second->GetVar(); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << recv_var_name; + PADDLE_THROW("Can not find server side var"); + } + auto fs = framework::Async([var_name, &executor, &v, prepared] { + try { + executor->RunPreparedContext(prepared, v.second->GetMutableLocalScope(), + false, false); + } catch (std::exception &e) { + LOG(ERROR) << "run sub program error " << e.what(); + } + }); + fs.wait(); + } +} + void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, - framework::ProgramDesc *program, - framework::Scope *recv_scope, - framework::BlockDesc *prefetch_block) const { + framework::ProgramDesc *program) const { VLOG(3) << "RunAsyncLoop in"; // grad name to block id std::unordered_map grad_to_block_id; std::unordered_map id_to_grad; + std::unordered_map> + grad_to_queue; auto grad_to_block_id_str = Attr>("grad_to_block_id"); @@ -220,6 +232,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0); int block_id = std::stoi(pieces[1]); grad_to_block_id[pieces[0]] = block_id; + grad_to_queue[pieces[0]] = std::make_shared(); id_to_grad[block_id] = pieces[0]; } size_t num_blocks = program->Size(); @@ -238,8 +251,21 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i]; } - VLOG(3) << "RunAsyncLoop into while"; bool exit_flag = false; + + VLOG(3) << "start async optimize threads"; + std::vector> fs; + for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) { + std::string grad_name = iter->first; + VLOG(3) << "create async update thread for " << grad_name; + fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor, + &grad_to_queue, &grad_to_prepared_ctx]() { + AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name], + executor, grad_to_prepared_ctx[grad_name].get()); + })); + } + + VLOG(3) << "RunAsyncLoop into while"; while (!exit_flag) { const detail::ReceivedMessage v = rpc_service_->Get(); auto recv_var_name = v.first; @@ -249,13 +275,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, break; } else { VLOG(3) << "received grad: " << recv_var_name; - auto var = v.second->GetVar(); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << recv_var_name; - PADDLE_THROW("Can not find server side var"); - } - AsyncExecuteBlock(executor, grad_to_prepared_ctx[recv_var_name].get(), - v.second->GetMutableLocalScope()); + grad_to_queue[recv_var_name]->Push(v); } if (exit_flag) { @@ -304,7 +324,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, if (sync_mode) { RunSyncLoop(&executor, program, &recv_scope, prefetch_block); } else { - RunAsyncLoop(&executor, program, &recv_scope, prefetch_block); + RunAsyncLoop(&executor, program); } } diff --git a/paddle/fluid/operators/listen_and_serv_op.h b/paddle/fluid/operators/listen_and_serv_op.h index 3cc0f3047733bea94daa310cd39cb0a4f44bef85..5c8fc31c9774a0f2e8233824459b29b42469bd1a 100644 --- a/paddle/fluid/operators/listen_and_serv_op.h +++ b/paddle/fluid/operators/listen_and_serv_op.h @@ -47,9 +47,7 @@ class ListenAndServOp : public framework::OperatorBase { framework::BlockDesc* prefetch_block) const; void RunAsyncLoop(framework::Executor* executor, - framework::ProgramDesc* program, - framework::Scope* recv_scope, - framework::BlockDesc* prefetch_block) const; + framework::ProgramDesc* program) const; void Stop() override; diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index cc71c2136a6756ff094f6e06b8e200c6a68db06a..acfad45704d4ea9e28711c019db3563489aab3ff 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -168,7 +168,9 @@ class ListenAndServ(object): 'endpoint': self.endpoint, 'Fanin': self.fan_in, 'OptimizeBlock': current_block, - 'PrefetchBlock': empty_block + 'PrefetchBlock': empty_block, + 'sync_mode': True, # did not support async now in layers + 'grad_to_block_id': [""] })