From 260bf5aceba7ef4a54c144d7397ad470c48b18e4 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Fri, 20 Apr 2018 18:22:35 +0800 Subject: [PATCH] add sync_mode --- paddle/fluid/operators/detail/grpc_server.cc | 48 ++++++++++++------- paddle/fluid/operators/detail/grpc_server.h | 4 +- .../operators/detail/grpc_server_test.cc | 2 +- paddle/fluid/operators/listen_and_serv_op.cc | 3 +- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index b92dc5949..be14de9f5 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -30,9 +30,13 @@ enum CallStatus { PROCESS = 0, FINISH }; class RequestBase { public: explicit RequestBase(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, const platform::DeviceContext* dev_ctx) - : service_(service), cq_(cq), status_(PROCESS), dev_ctx_(dev_ctx) { + : service_(service), + cq_(cq), + sync_mode_(sync_mode), + status_(PROCESS), + dev_ctx_(dev_ctx) { PADDLE_ENFORCE(cq_); } virtual ~RequestBase() {} @@ -49,6 +53,7 @@ class RequestBase { ::grpc::ServerContext ctx_; GrpcService::AsyncService* service_; ::grpc::ServerCompletionQueue* cq_; + const bool sync_mode_; CallStatus status_; const platform::DeviceContext* dev_ctx_; }; @@ -56,11 +61,17 @@ class RequestBase { class RequestSend final : public RequestBase { public: explicit RequestSend(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, ReceivedQueue* queue, const platform::DeviceContext* dev_ctx) - : RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) { - request_.reset(new VariableResponse(false, scope, dev_ctx_)); + : RequestBase(service, cq, sync_mode, dev_ctx), + queue_(queue), + responder_(&ctx_) { + if (sync_mode_) { + request_.reset(new VariableResponse(false, scope, dev_ctx_)); + } else { + request_.reset(new VariableResponse(true, scope, dev_ctx_)); + } int method_id = static_cast(detail::GrpcMethod::kSendVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -87,11 +98,11 @@ class RequestSend final : public RequestBase { class RequestGet final : public RequestBase { public: explicit RequestGet(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, SimpleBlockQueue* queue) - : RequestBase(service, cq, dev_ctx), + : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), queue_(queue) { @@ -134,19 +145,23 @@ class RequestGet final : public RequestBase { class RequestPrefetch final : public RequestBase { public: explicit RequestPrefetch(GrpcService::AsyncService* service, - ::grpc::ServerCompletionQueue* cq, + ::grpc::ServerCompletionQueue* cq, bool sync_mode, framework::Scope* scope, const platform::DeviceContext* dev_ctx, framework::Executor* executor, framework::ProgramDesc* program, framework::ExecutorPrepareContext* prefetch_ctx) - : RequestBase(service, cq, dev_ctx), + : RequestBase(service, cq, sync_mode, dev_ctx), responder_(&ctx_), scope_(scope), executor_(executor), program_(program), prefetch_ctx_(prefetch_ctx) { - request_.reset(new VariableResponse(false, scope, dev_ctx_)); + if (sync_mode_) { + request_.reset(new VariableResponse(false, scope, dev_ctx_)); + } else { + request_.reset(new VariableResponse(true, scope, dev_ctx_)); + } int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, cq_, cq_, this); @@ -181,7 +196,6 @@ class RequestPrefetch final : public RequestBase { framework::Executor* executor_; framework::ProgramDesc* program_; framework::ExecutorPrepareContext* prefetch_ctx_; - int blkid_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -254,8 +268,8 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { VLOG(3) << "shutdown, do not TryToRegisterNewSendOne"; return; } - RequestSend* send = new RequestSend(&service_, cq_send_.get(), scope_, - &var_recv_queue_, dev_ctx_); + RequestSend* send = new RequestSend(&service_, cq_send_.get(), sync_mode_, + scope_, &var_recv_queue_, dev_ctx_); VLOG(4) << "Create RequestSend status:" << send->Status(); } @@ -265,8 +279,8 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { VLOG(3) << "shutdown, do not TryToRegisterNewGetOne"; return; } - RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, - &var_get_queue_); + RequestGet* get = new RequestGet(&service_, cq_get_.get(), sync_mode_, scope_, + dev_ctx_, &var_get_queue_); VLOG(4) << "Create RequestGet status:" << get->Status(); } @@ -277,8 +291,8 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne() { return; } RequestPrefetch* prefetch = - new RequestPrefetch(&service_, cq_prefetch_.get(), scope_, dev_ctx_, - executor_, program_, prefetch_ctx_); + new RequestPrefetch(&service_, cq_prefetch_.get(), sync_mode_, scope_, + dev_ctx_, executor_, program_, prefetch_ctx_); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index 452ff5e96..ae660ef48 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -44,7 +44,8 @@ class RequestBase; class AsyncGRPCServer final { public: - explicit AsyncGRPCServer(const std::string &address) : address_(address) {} + explicit AsyncGRPCServer(const std::string &address, bool sync_mode) + : address_(address), sync_mode_(sync_mode) {} void RunSyncUpdate(); @@ -95,6 +96,7 @@ class AsyncGRPCServer final { std::unique_ptr<::grpc::Server> server_; std::string address_; + const bool sync_mode_; framework::Scope *scope_; const platform::DeviceContext *dev_ctx_; diff --git a/paddle/fluid/operators/detail/grpc_server_test.cc b/paddle/fluid/operators/detail/grpc_server_test.cc index c51933718..25b95d608 100644 --- a/paddle/fluid/operators/detail/grpc_server_test.cc +++ b/paddle/fluid/operators/detail/grpc_server_test.cc @@ -89,7 +89,7 @@ void InitTensorsOnServer(framework::Scope* scope, platform::CPUPlace* place, } void StartServer(const std::string& endpoint) { - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, true)); framework::ProgramDesc program; framework::Scope scope; platform::CPUPlace place; diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index e5e447164..db8a0cd63 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -278,7 +278,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, PADDLE_ENFORCE(!rpc_service_); std::string endpoint = Attr("endpoint"); - rpc_service_.reset(new detail::AsyncGRPCServer(endpoint)); + + rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode)); auto *optimize_block = Attr(kOptimizeBlock); auto *prefetch_block = Attr(kPrefetchBlock); -- GitLab