From 5f4d9130f01833dfef44dac2eadb7089accbe0ba Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 18 Jan 2018 19:27:20 +0800 Subject: [PATCH] merge codes --- paddle/operators/detail/grpc_server.cc | 5 +++-- paddle/operators/detail/grpc_server.h | 6 ++---- paddle/operators/recv_op.cc | 15 +++++---------- 3 files changed, 10 insertions(+), 16 deletions(-) diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index 42d3cc575..3ddcd839b 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -162,7 +162,6 @@ void AsyncGRPCServer::ShutdownQueue() { } // This URL explains why shutdown is complicate: -// https://stackoverflow.com/questions/35708348/grpc-what-is-the-recommended-way-to-shut-down-an-asynchronous-server-in-c void AsyncGRPCServer::ShutDown() { server_->Shutdown(); ShutdownQueue(); @@ -188,6 +187,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { VLOG(4) << "create Requestget status:" << get->Status(); } +// FIXME(typhoonzero): remove wait argument and change cq_name to enum. void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, std::string cq_name, std::function TryToRegisterNewOne) { @@ -202,7 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, } PADDLE_ENFORCE(tag); - if (cq_name == "cq_get") WaitCond(2); + // FIXME(typhoonzero): de-couple the barriers with recv_op + if (cq_name == "cq_get") WaitCond(1); if (cq_name == "cq_send") WaitCond(0); RequestBase* base = (RequestBase*)tag; diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 5c7be5f5b..1ca9086c7 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -42,10 +42,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void RunSyncUpdate(); // functions to sync server barrier status. - void WaitStart(); - void WaitDone(); - void Start(); - void Done(); + void WaitCond(int cond); + void SetCond(int cond); void WaitClientGet(int count); void SetScope(framework::Scope *scope) { scope_ = scope; } diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 2ecd56671..8d1479bdd 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -105,15 +105,14 @@ class RecvOp : public framework::OperatorBase { framework::ProgramDesc program(program_desc); framework::Executor executor(dev_place); - // rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; + int64_t barrier_size = param_count * fan_in; while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. - rpc_service_->SetCond(kCondStart); - VLOG(3) << "================ start get from service ==========="; - for (size_t i = 0; i < param_count * fan_in; ++i) { + rpc_service_->SetCond(0); + for (size_t i = 0; i < barrier_size; ++i) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { @@ -130,8 +129,6 @@ class RecvOp : public framework::OperatorBase { } VLOG(3) << "recved grad: " << grad_var_name << " updating param: " << param_var_name; - // Assume grad_var_name must appear in global scope. - std::string grad_var_name_trainer; if (fan_in > 1) { grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); } @@ -145,16 +142,14 @@ class RecvOp : public framework::OperatorBase { if (exit_flag) { break; } - // rpc_service_->Reset(); try { executor.Run(program, &recv_scope, 0, /*global_block*/ false /*create_local_scope*/, false /*create_vars*/); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } - VLOG(3) << "================ run sub program end ==========="; - rpc_service_->SetCond(kCondDone); - rpc_service_->WaitClientGet(param_count * fan_in); + rpc_service_->SetCond(1); + rpc_service_->WaitClientGet(barrier_size); grads_counter_.clear(); } // while(true) } -- GitLab