未验证 提交 1d198494 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #11370 from panyx0718/dist

Make status update thread-safe
...@@ -41,11 +41,22 @@ class RequestBase { ...@@ -41,11 +41,22 @@ class RequestBase {
virtual ~RequestBase() {} virtual ~RequestBase() {}
virtual void Process() = 0; virtual void Process() = 0;
CallStatus Status() { return status_; } CallStatus Status() const {
void SetStatus(CallStatus status) { status_ = status; } std::lock_guard<std::mutex> l(status_mu_);
return status_;
}
template <typename T>
void Finish(const T& reply, ServerAsyncResponseWriter<T>* responder) {
std::lock_guard<std::mutex> l(status_mu_);
status_ = FINISH;
responder->Finish(reply, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
}
virtual std::string GetReqName() = 0; virtual std::string GetReqName() = 0;
protected: protected:
mutable std::mutex status_mu_;
::grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
GrpcService::AsyncService* service_; GrpcService::AsyncService* service_;
::grpc::ServerCompletionQueue* cq_; ::grpc::ServerCompletionQueue* cq_;
...@@ -80,9 +91,7 @@ class RequestSend final : public RequestBase { ...@@ -80,9 +91,7 @@ class RequestSend final : public RequestBase {
framework::Variable* outvar = nullptr; framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar); request_handler_->Handle(varname, scope, invar, &outvar);
status_ = FINISH; Finish(reply_, &responder_);
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
...@@ -122,9 +131,7 @@ class RequestGet final : public RequestBase { ...@@ -122,9 +131,7 @@ class RequestGet final : public RequestBase {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
} }
status_ = FINISH; Finish(reply_, &responder_);
responder_.Finish(reply_, ::grpc::Status::OK,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
} }
protected: protected:
...@@ -166,9 +173,7 @@ class RequestPrefetch final : public RequestBase { ...@@ -166,9 +173,7 @@ class RequestPrefetch final : public RequestBase {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(), SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
&reply_); &reply_);
responder_.Finish(reply_, ::grpc::Status::OK, Finish(reply_, &responder_);
reinterpret_cast<void*>(static_cast<intptr_t>(req_id_)));
status_ = FINISH;
} }
protected: protected:
......
...@@ -53,6 +53,7 @@ class AsyncGRPCServer final : public RPCServer { ...@@ -53,6 +53,7 @@ class AsyncGRPCServer final : public RPCServer {
void StartServer() override; void StartServer() override;
private: private:
// HandleRequest needs to be thread-safe.
void HandleRequest( void HandleRequest(
::grpc::ServerCompletionQueue* cq, const std::string& rpc_name, ::grpc::ServerCompletionQueue* cq, const std::string& rpc_name,
std::function<void(const std::string&, int)> TryToRegisterNewOne); std::function<void(const std::string&, int)> TryToRegisterNewOne);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册