From a0ced3df82e53ca9acc6db9827da5dfb172a7097 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 23 Apr 2018 18:33:11 +0800 Subject: [PATCH] async update can run --- paddle/fluid/operators/detail/grpc_server.cc | 12 +++++++----- .../fluid/operators/detail/variable_response.h | 2 +- paddle/fluid/operators/listen_and_serv_op.cc | 18 +++++++++--------- paddle/fluid/operators/send_op.cc | 13 +++++++++---- python/paddle/fluid/distribute_transpiler.py | 12 +++++++----- 5 files changed, 33 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 27ddb675009..60d7cc68fc0 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -315,9 +315,11 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, VLOG(3) << "HandleRequest for " << cq_name << " while after Next"; PADDLE_ENFORCE(tag); - // FIXME(typhoonzero): de-couple the barriers with recv_op - if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); - if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); + if (sync_mode_) { + // FIXME(typhoonzero): de-couple the barriers with recv_op + if (!is_shut_down_ && cq_name == "cq_get") WaitCond(1); + if (!is_shut_down_ && cq_name == "cq_send") WaitCond(0); + } RequestBase* base = reinterpret_cast(tag); // reference: @@ -334,13 +336,13 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq, switch (base->Status()) { case PROCESS: { - VLOG(4) << cq_name << " status:" << base->Status(); + VLOG(4) << cq_name << " PROCESS status:" << base->Status(); TryToRegisterNewOne(); base->Process(); break; } case FINISH: { - VLOG(4) << cq_name << " status:" << base->Status(); + VLOG(4) << cq_name << " FINISH status:" << base->Status(); delete base; break; } diff --git a/paddle/fluid/operators/detail/variable_response.h b/paddle/fluid/operators/detail/variable_response.h index 3018a5c4af8..59a92f7155b 100644 --- a/paddle/fluid/operators/detail/variable_response.h +++ b/paddle/fluid/operators/detail/variable_response.h @@ -61,7 +61,7 @@ class VariableResponse { // other: number of error field. int Parse(const ::grpc::ByteBuffer& byte_buffer); - const framework::Scope& GetLocalScope() const { return *local_scope_; } + framework::Scope& GetLocalScope() const { return *local_scope_; } inline std::string Varname() { return meta_.varname(); } inline std::string OutVarname() { return meta_.out_varname(); } diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index a01f0aef8d4..bf351a2f520 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -48,13 +48,15 @@ static void split(const std::string &str, char sep, static void AsyncExecuteBlock(framework::Executor *executor, framework::ExecutorPrepareContext *prepared, framework::Scope *scope) { - framework::Async([&executor, &prepared, &scope]() { + std::future future = framework::Async([&executor, &prepared, &scope]() { try { executor->RunPreparedContext(prepared, scope, false, false); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } }); + // TODO(qiao) maybe we can remove this + future.wait(); } static void ParallelExecuteBlocks( @@ -203,6 +205,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, framework::ProgramDesc *program, framework::Scope *recv_scope, framework::BlockDesc *prefetch_block) const { + VLOG(3) << "RunAsyncLoop in"; // grad name to block id std::unordered_map grad_to_id; std::unordered_map id_to_grad; @@ -210,7 +213,8 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, auto grad_to_id_str = Attr>("grad_to_id"); for (auto &grad_and_id : grad_to_id_str) { std::vector pieces; - split(grad_and_id, ' ', &pieces); + split(grad_and_id, ':', &pieces); + VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1]; PADDLE_ENFORCE_EQ(pieces.size(), 2); PADDLE_ENFORCE_EQ(grad_to_id.count(pieces[0]), 0); int block_id = std::stoi(pieces[1]); @@ -223,14 +227,9 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, std::vector block_list; for (size_t blkid = 1; blkid < num_blocks; ++blkid) { - if (blkid != static_cast(prefetch_block->ID())) { - block_list.push_back(blkid); - } + block_list.push_back(blkid); } - PADDLE_ENFORCE_EQ(grad_to_id_str.size(), block_list.size(), - "grad num should be equal to optimize block num"); auto optimize_prepared = executor->Prepare(*program, block_list); - std::unordered_map> grad_to_prepared; @@ -238,6 +237,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, grad_to_prepared[id_to_grad[block_list[i]]] = optimize_prepared[i]; } + VLOG(3) << "RunAsyncLoop into while"; bool exit_flag = false; while (!exit_flag) { const detail::ReceivedMessage v = rpc_service_->Get(); @@ -254,7 +254,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor, PADDLE_THROW("Can not find server side var"); } AsyncExecuteBlock(executor, grad_to_prepared[recv_var_name].get(), - recv_scope); + &(v.second->GetLocalScope())); // TODO(qiao): explain why if (var->IsType()) { var->GetMutable()->mutable_rows()->clear(); diff --git a/paddle/fluid/operators/send_op.cc b/paddle/fluid/operators/send_op.cc index 82ff087d0a7..e4386b640a2 100644 --- a/paddle/fluid/operators/send_op.cc +++ b/paddle/fluid/operators/send_op.cc @@ -41,6 +41,8 @@ class SendOp : public framework::OperatorBase { std::vector endpoints = Attr>("endpoints"); + bool sync_mode = Attr("sync_mode"); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); @@ -64,11 +66,13 @@ class SendOp : public framework::OperatorBase { } PADDLE_ENFORCE(rpc_client->Wait()); - for (auto& ep : endpoints) { - VLOG(3) << "batch barrier, ep: " << ep; - rpc_client->AsyncSendBatchBarrier(ep); + if (sync_mode) { + for (auto& ep : endpoints) { + VLOG(3) << "batch barrier, ep: " << ep; + rpc_client->AsyncSendBatchBarrier(ep); + } + PADDLE_ENFORCE(rpc_client->Wait()); } - PADDLE_ENFORCE(rpc_client->Wait()); if (outs.size() > 0) { for (size_t i = 0; i < outs.size(); i++) { @@ -112,6 +116,7 @@ This operator will send tensor to recv_op at the parameter server. "Server endpoints in the order of input " "variables for mapping") .SetDefault({}); + AddAttr("sync_mode", "work in sync_mode or not").SetDefault(true); } }; diff --git a/python/paddle/fluid/distribute_transpiler.py b/python/paddle/fluid/distribute_transpiler.py index 6ac3b826719..3a3a94640a1 100644 --- a/python/paddle/fluid/distribute_transpiler.py +++ b/python/paddle/fluid/distribute_transpiler.py @@ -297,8 +297,11 @@ class DistributeTranspiler: inputs={"X": send_inputs}, outputs={"Out": send_outputs, "RPCClient": rpc_client_var}, - attrs={"endpoints": pserver_endpoints, - "epmap": eplist}) + attrs={ + "endpoints": pserver_endpoints, + "epmap": eplist, + "sync_mode": self.sync_mode + }) # step4: Concat the parameters splits together after recv. for varname, splited_var in param_var_mapping.iteritems(): if len(splited_var) <= 1: @@ -404,8 +407,8 @@ class DistributeTranspiler: for op in self.optimize_ops: if op.type == "scale": for in_name in op.input_arg_names: - if in_name.startswith("beta1_pow_acc") or\ - in_name.startswith("beta2_pow_acc"): + if in_name.startswith("beta1_pow_acc") or \ + in_name.startswith("beta2_pow_acc"): global_ops.append(op) def __append_optimize_op__(op, block, grad_to_block_id): @@ -434,7 +437,6 @@ class DistributeTranspiler: __append_optimize_op__(op, per_opt_block, grad_to_block_id) # append global ops - opt_state_block = None if global_ops: opt_state_block = pserver_program.create_block( pserver_program.num_blocks - 1) -- GitLab