diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index b92dc59491955ed319a2c2719244be93869a90c5..be14de9f5ed796f9bd914713d55679fbe16ea01c 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH }; class RequestBase { public: explicit RequestBase(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, const platform::DeviceContext* dev_ctx) - : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { + : service_(service), + cq_(cq), + sync_mode_(sync_mode), + status_(PROCESS), + dev_ctx_(dev_ctx) { PADDLE_ENFORCE(cq_); } virtual ~RequestBase() {} @@ -49,6 +53,7 @@ class RequestBase { ::grpc::ServerContext ctx_; GrpcService::AsyncService* service_; ::grpc::ServerCompletionQueue* cq_; + const bool sync_mode_; CallStatus status_; const platform::DeviceContext* dev_ctx_; }; @@ -56,11 +61,17 @@ class RequestBase { class RequestSend final : public RequestBase { public: explicit RequestSend(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, ReceivedQueue* queue, const platform::DeviceContext* dev_ctx) - : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { - request_.reset(new VariableResponse(false, scope, dev_ctx_)); + : RequestBase(service, cq, sync_mode, dev_ctx), + queue_(queue), + responder_(&ctx_) { + if (sync_mode_) { + request_.reset(new VariableResponse(false, scope, dev_ctx_)); + } else { + request_.reset(new VariableResponse(true, scope, dev_ctx_)); + } int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -87,11 +98,11 @@ class RequestSend final : public RequestBase { class RequestGet final : public RequestBase { public: explicit RequestGet(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, SimpleBlockQueue* queue) - : RequestBase(service, cq, dev_ctx), + : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), queue_(queue) { @@ -134,19 +145,23 @@ class RequestGet final : public RequestBase { class RequestPrefetch final : public RequestBase { public: explicit RequestPrefetch(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, framework::Executor* executor, framework::ProgramDesc* program, framework::ExecutorPrepareContext* prefetch_ctx) - : RequestBase(service, cq, dev_ctx), + : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), prefetch_ctx_(prefetch_ctx) { - request_.reset(new VariableResponse(false, scope, dev_ctx_)); + if (sync_mode_) { + request_.reset(new VariableResponse(false, scope, dev_ctx_)); + } else { + request_.reset(new VariableResponse(true, scope, dev_ctx_)); + } int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor_; framework::ProgramDesc* program_; framework::ExecutorPrepareContext* prefetch_ctx_; - int blkid_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } - RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, - &var_recv_queue_, dev_ctx_); + RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, + scope_, &var_recv_queue_, dev_ctx_); VLOG(4) << "Create RequestSend status:" << send->Status(); } @@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; return; } - RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, - &var_get_queue_); + RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, + dev_ctx_, &var_get_queue_); VLOG(4) << "Create RequestGet status:" << get->Status(); } @@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { return; } RequestPrefetch* prefetch = - new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, - executor_, program_, prefetch_ctx_); + new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, + dev_ctx_, executor_, program_, prefetch_ctx_); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 452ff5e967c086340e065a1b6a4b8672c75a4a3d..ae660ef480864ca6172ca1c597a38511066eb3a3 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -44,7 +44,8 @@ class RequestBase; class AsyncGRPCServer final { public: - explicit AsyncGRPCServer(const std::string &address) : address_(address) {} + explicit AsyncGRPCServer(const std::string &address, bool sync_mode) + : address_(address), sync_mode_(sync_mode) {} void RunSyncUpdate(); @@ -95,6 +96,7 @@ class AsyncGRPCServer final { std::unique_ptr<::grpc::Server> server_; std::string address_; + const bool sync_mode_; framework::Scope *scope_; const platform::DeviceContext *dev_ctx_; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index c51933718f4ca78e87c77e007c485642000d247d..25b95d608d10d6e456d5f563ce9fbe35d812cb0f 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, } void StartServer(const std::string& endpoint) { - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true)); framework::ProgramDesc program; framework::Scope scope; platform::CPUPlace place; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index e5e447164f52b9bae149244ec4c1af48dc48e8d6..db8a0cd631b270bf89b49c9d7326e48669a0378c 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -278,7 +278,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode)); auto *optimize_block = Attr(kOptimizeBlock); auto *prefetch_block = Attr(kPrefetchBlock);