diff --git a/paddle/fluid/operators/detail/grpc_server.cc b/paddle/fluid/operators/detail/grpc_server.cc index c2c1df4cd6485933945077451464659d6b89f0f4..51ddda6255b8d0a95ed44d213235fe5fb1a0e1ce 100644 --- a/paddle/fluid/operators/detail/grpc_server.cc +++ b/paddle/fluid/operators/detail/grpc_server.cc @@ -25,6 +25,7 @@ namespace detail { namespace { const int kNumHandleSendThreads = 20; const int kNumHandleGetThreads = 20; +const int kNumHandlePrefetchThreads = 1; } // namespace enum CallStatus { PROCESS = 0, FINISH }; @@ -180,8 +181,9 @@ class RequestPrefetch final : public RequestBase { request_.reset(new VariableResponse(scope, dev_ctx_, true)); } int method_id = static_cast(detail::GrpcMethod::kPrefetchVariable); - service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_, - cq_, cq_, this); + service_->RequestAsyncUnary( + method_id, &ctx_, request_.get(), &responder_, cq_, cq_, + reinterpret_cast(static_cast(req_id_))); } virtual ~RequestPrefetch() {} @@ -190,7 +192,6 @@ class RequestPrefetch final : public RequestBase { virtual void Process() { // prefetch process... - ::grpc::ByteBuffer reply; std::string var_name = request_->OutVarname(); VLOG(3) << "RequestPrefetch " << var_name; @@ -200,15 +201,16 @@ class RequestPrefetch final : public RequestBase { InitializeVariable(var, var_desc->GetType()); executor_->RunPreparedContext(prefetch_ctx_, scope_); - SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply); + SerializeToByteBuffer(var_name, var, *dev_ctx_, &reply_); - responder_.Finish(reply, ::grpc::Status::OK, - reinterpret_cast(static_cast(req_id_))); status_ = FINISH; + responder_.Finish(reply_, ::grpc::Status::OK, + reinterpret_cast(static_cast(req_id_))); } protected: std::shared_ptr request_; + ::grpc::ByteBuffer reply_; ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_; framework::Scope* scope_; framework::Executor* executor_; @@ -262,6 +264,9 @@ void AsyncGRPCServer::RunSyncUpdate() { for (int i = 0; i < kGetReqsBufSize; ++i) { TryToRegisterNewGetOne(i); } + for (int i = 0; i < kPrefetchReqsBufSize; ++i) { + TryToRegisterNewPrefetchOne(i); + } for (int i = 0; i < kNumHandleSendThreads; ++i) { t_sends_.emplace_back( @@ -273,12 +278,11 @@ void AsyncGRPCServer::RunSyncUpdate() { new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); } - - // TODO(wuyi): Run these "HandleRequest" in thread pool - t_prefetch_.reset(new std::thread( - std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), - "cq_prefetch", prefetch_register))); - + for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + t_prefetchs_.emplace_back(new std::thread( + std::bind(&AsyncGRPCServer::HandleRequest, this, cq_prefetch_.get(), + "cq_prefetch", prefetch_register))); + } { std::lock_guard lock(this->mutex_ready_); ready_ = 1; @@ -292,7 +296,9 @@ void AsyncGRPCServer::RunSyncUpdate() { for (int i = 0; i < kNumHandleGetThreads; ++i) { t_gets_[i]->join(); } - t_prefetch_->join(); + for (int i = 0; i < kNumHandlePrefetchThreads; ++i) { + t_prefetchs_[i]->join(); + } } void AsyncGRPCServer::ShutdownQueue() { @@ -342,6 +348,7 @@ void AsyncGRPCServer::TryToRegisterNewPrefetchOne(int req_id) { RequestPrefetch* prefetch = new RequestPrefetch( &service_, cq_prefetch_.get(), sync_mode_, scope_, dev_ctx_, executor_, program_, prefetch_ctx_.get(), req_id); + prefetch_reqs_[req_id] = static_cast(prefetch); VLOG(4) << "Create RequestPrefetch status:" << prefetch->Status(); } @@ -376,8 +383,8 @@ void AsyncGRPCServer::HandleRequest( base = get_reqs_[req_id]; } else if (cq_name == "cq_send") { base = send_reqs_[req_id]; - } else { - CHECK(false); + } else if (cq_name == "cq_prefetch") { + base = prefetch_reqs_[req_id]; } } // reference: diff --git a/paddle/fluid/operators/detail/grpc_server.h b/paddle/fluid/operators/detail/grpc_server.h index d70be1b7ce998a0faaaa256d62f6f1fcf5c267a0..9a60ee5579a6a50d913123d061dc43625ccc6013 100644 --- a/paddle/fluid/operators/detail/grpc_server.h +++ b/paddle/fluid/operators/detail/grpc_server.h @@ -93,6 +93,7 @@ class AsyncGRPCServer final { private: static const int kSendReqsBufSize = 100; static const int kGetReqsBufSize = 100; + static const int kPrefetchReqsBufSize = 10; std::mutex cq_mutex_; volatile bool is_shut_down_ = false; @@ -102,6 +103,7 @@ class AsyncGRPCServer final { RequestBase *send_reqs_[kSendReqsBufSize]; RequestBase *get_reqs_[kGetReqsBufSize]; + RequestBase *prefetch_reqs_[kPrefetchReqsBufSize]; GrpcService::AsyncService service_; std::unique_ptr<::grpc::Server> server_; @@ -123,6 +125,7 @@ class AsyncGRPCServer final { std::vector> t_sends_; std::vector> t_gets_; + std::vector> t_prefetchs_; std::unique_ptr t_prefetch_;