From 1509ae3a53adfc146053d5ae85ae55ab73f080f1 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 11 Jun 2018 20:20:59 +0800 Subject: [PATCH] Make status update thread-safe The status is updated in the Process() thread and can be checked in another HandleRequest() thread. --- paddle/fluid/operators/detail/grpc_server.cc | 27 ++++++++++++-------- paddle/fluid/operators/detail/grpc_server.h | 1 + 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 57867aad4..18651544a 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -41,11 +41,22 @@ class RequestBase { virtual ~RequestBase() {} virtual void Process() = 0; - CallStatus Status() { return status_; } - void SetStatus(CallStatus status) { status_ = status; } + CallStatus Status() const { + std::lock_guard l(status_mu_); + return status_; + } + + template + void Finish(const T& reply, ServerAsyncResponseWriter* responder) { + std::lock_guard l(status_mu_); + status_ = FINISH; + responder->Finish(reply, ::grpc::Status::OK, + reinterpret_cast(static_cast(req_id_))); + } virtual std::string GetReqName() = 0; protected: + mutable std::mutex status_mu_; ::grpc::ServerContext ctx_; GrpcService::AsyncService* service_; ::grpc::ServerCompletionQueue* cq_; @@ -80,9 +91,7 @@ class RequestSend final : public RequestBase { framework::Variable* outvar = nullptr; request_handler_->Handle(varname, scope, invar, &outvar); - status_ = FINISH; - responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); + Finish(reply_, &responder_); } protected: @@ -122,9 +131,7 @@ class RequestGet final : public RequestBase { SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), &reply_); } - status_ = FINISH; - responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); + Finish(reply_, &responder_); } protected: @@ -166,9 +173,7 @@ class RequestPrefetch final : public RequestBase { SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), &reply_); - responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); - status_ = FINISH; + Finish(reply_, &responder_); } protected: diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index e6ffc7066..f1db7590f 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -53,6 +53,7 @@ class AsyncGRPCServer final : public RPCServer { void StartServer() override; private: + // HandleRequest needs to be thread-safe. void HandleRequest( ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, std::function TryToRegisterNewOne); -- GitLab