diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index 57867aad4d679f75ea790b65b5773a73586fd96e..18651544a1c0207127be335c37fe85c6e24dc16e 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -41,11 +41,22 @@ class RequestBase { virtual ~RequestBase() {} virtual void Process() = 0; - CallStatus Status() { return status_; } - void SetStatus(CallStatus status) { status_ = status; } + CallStatus Status() const { + std::lock_guard l(status_mu_); + return status_; + } + + template + void Finish(const T& reply, ServerAsyncResponseWriter* responder) { + std::lock_guard l(status_mu_); + status_ = FINISH; + responder->Finish(reply, ::grpc::Status::OK, + reinterpret_cast(static_cast(req_id_))); + } virtual std::string GetReqName() = 0; protected: + mutable std::mutex status_mu_; ::grpc::ServerContext ctx_; GrpcService::AsyncService* service_; ::grpc::ServerCompletionQueue* cq_; @@ -80,9 +91,7 @@ class RequestSend final : public RequestBase { framework::Variable* outvar = nullptr; request_handler_->Handle(varname, scope, invar, &outvar); - status_ = FINISH; - responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); + Finish(reply_, &responder_); } protected: @@ -122,9 +131,7 @@ class RequestGet final : public RequestBase { SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), &reply_); } - status_ = FINISH; - responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); + Finish(reply_, &responder_); } protected: @@ -166,9 +173,7 @@ class RequestPrefetch final : public RequestBase { SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), &reply_); - responder_.Finish(reply_, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); - status_ = FINISH; + Finish(reply_, &responder_); } protected: diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index e6ffc7066f24d5088a95801ed1c0670b24d5771f..f1db7590f6f14d5d44acc12453861a446e278cd2 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -53,6 +53,7 @@ class AsyncGRPCServer final : public RPCServer { void StartServer() override; private: + // HandleRequest needs to be thread-safe. void HandleRequest( ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, std::function TryToRegisterNewOne);