diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index c0b94746a0b7f6ffb657bbf5af18360426933858..42d3cc57584d9dc546263e277b46e942c0d71f5d 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -36,7 +36,10 @@ class RequestBase { CallStatus Status() { return status_; } void SetStatus(CallStatus status) { status_ = status; } - virtual std::string GetReqName() { assert(false); } + virtual std::string GetReqName() { + assert(false); + return ""; + } protected: grpc::ServerContext ctx_; @@ -80,11 +83,13 @@ class RequestGet final : public RequestBase { public: explicit RequestGet(sendrecv::SendRecvService::AsyncService* service, grpc::ServerCompletionQueue* cq, framework::Scope* scope, - const platform::DeviceContext* dev_ctx) + const platform::DeviceContext* dev_ctx, + SimpleBlockQueue* queue) : RequestBase(service, cq), responder_(&ctx_), scope_(scope), - dev_ctx_(dev_ctx) { + dev_ctx_(dev_ctx), + queue_(queue) { service_->RequestGetVariable(&ctx_, &request_, &responder_, cq_, cq_, this); } @@ -100,6 +105,7 @@ class RequestGet final : public RequestBase { // TODO(gongwb): check var's info. responder_.Finish(reply_, grpc::Status::OK, this); status_ = FINISH; + queue_->Push('c'); } protected: @@ -108,8 +114,15 @@ class RequestGet final : public RequestBase { ServerAsyncResponseWriter responder_; framework::Scope* scope_; const platform::DeviceContext* dev_ctx_; + SimpleBlockQueue* queue_; }; +void AsyncGRPCServer::WaitClientGet(int count) { + for (int i = 0; i < count; ++i) { + var_get_queue_.Pop(); + } +} + void AsyncGRPCServer::RunSyncUpdate() { grpc::ServerBuilder builder; builder.AddListeningPort(address_, grpc::InsecureServerCredentials()); @@ -170,7 +183,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { if (is_shut_down_) { return; } - RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_); + RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, + &var_get_queue_); VLOG(4) << "create Requestget status:" << get->Status(); } @@ -188,9 +202,8 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, } PADDLE_ENFORCE(tag); - if (wait && !done_) { - Wait(); - } + if (cq_name == "cq_get") WaitCond(2); + if (cq_name == "cq_send") WaitCond(0); RequestBase* base = (RequestBase*)tag; // reference: @@ -222,22 +235,18 @@ void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, } } -void AsyncGRPCServer::Wait() { - std::unique_lock lock(this->mutex_); - condition_.wait(lock, [=] { return this->done_ == true; }); -} - -void AsyncGRPCServer::Reset() { - std::lock_guard lock(this->mutex_); - done_ = false; +void AsyncGRPCServer::WaitCond(int cond) { + std::unique_lock lock(this->barrier_mutex_); + barrier_condition_.wait(lock, + [=] { return this->barrier_cond_step_ == cond; }); } -void AsyncGRPCServer::Done() { +void AsyncGRPCServer::SetCond(int cond) { { - std::lock_guard lock(this->mutex_); - done_ = true; + std::lock_guard lock(this->barrier_mutex_); + barrier_cond_step_ = cond; } - condition_.notify_all(); + barrier_condition_.notify_all(); } } // namespace detail diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 2c078b77771656dc7fc0342ecf21b8d33dc11817..5c7be5f5bd2560aabb272a67e26ada1724454ad9 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -41,9 +41,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void RunSyncUpdate(); - void Reset(); - + // functions to sync server barrier status. + void WaitStart(); + void WaitDone(); + void Start(); void Done(); + void WaitClientGet(int count); void SetScope(framework::Scope *scope) { scope_ = scope; } @@ -56,7 +59,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void ShutDown(); protected: - void Wait(); void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, std::string cq_name, std::function TryToRegisterNewOne); @@ -78,11 +80,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { const platform::DeviceContext *dev_ctx_; // received variable from RPC, operators fetch variable from this queue. SimpleBlockQueue var_recv_queue_; + SimpleBlockQueue var_get_queue_; // condition of the sub program - std::mutex mutex_; - volatile mutable bool done_; - std::condition_variable condition_; + std::mutex barrier_mutex_; + mutable int barrier_cond_step_; + std::condition_variable barrier_condition_; std::unique_ptr t_send_; std::unique_ptr t_get_; diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index b77d150dccfbed74b4a791c40a8e814a13161b09..2ecd56671f1c4081d3d188d3e4c1fc3a7ef878bb 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -34,6 +34,10 @@ limitations under the License. */ namespace paddle { namespace operators { +constexpr int kCondStart = 0; +constexpr int kCondRunning = 1; +constexpr int kCondDone = 2; + void RunServer(std::shared_ptr service) { service->RunSyncUpdate(); VLOG(4) << "RunServer thread end"; @@ -101,12 +105,14 @@ class RecvOp : public framework::OperatorBase { framework::ProgramDesc program(program_desc); framework::Executor executor(dev_place); - rpc_service_->Reset(); + // rpc_service_->Reset(); // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. + rpc_service_->SetCond(kCondStart); + VLOG(3) << "================ start get from service ==========="; for (size_t i = 0; i < param_count * fan_in; ++i) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; @@ -139,15 +145,16 @@ class RecvOp : public framework::OperatorBase { if (exit_flag) { break; } - rpc_service_->Reset(); + // rpc_service_->Reset(); try { executor.Run(program, &recv_scope, 0, /*global_block*/ false /*create_local_scope*/, false /*create_vars*/); } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } - - rpc_service_->Done(); + VLOG(3) << "================ run sub program end ==========="; + rpc_service_->SetCond(kCondDone); + rpc_service_->WaitClientGet(param_count * fan_in); grads_counter_.clear(); } // while(true) }