diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index 1e41587c418fb0ce4e452d5c6735c54e2d42f798..d699dabf2fb982f267c4869180efaf0e600eb46c 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "grpc_client.h" +#include "paddle/framework/threadpool.h" namespace paddle { namespace operators { namespace detail { @@ -22,25 +23,32 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, const framework::Scope& scope, const std::string& var_name, int64_t time_out) { - sendrecv::VariableMessage req; - auto* var = scope.FindVar(var_name); - SerializeToMessage(var_name, var, ctx, &req); - - // varhandle - VarHandle var_h; - var_h.ep = ep; - var_h.scope = &scope; - var_h.name = var_name; - var_h.ctx = &ctx; - - // stub context - auto ch = GetChannel(ep); - SendProcessor* s = new SendProcessor(ch); - s->Prepare(var_h, time_out); - s->response_call_back_ = NULL; - - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, (void*)s); + const platform::DeviceContext* p_ctx = &ctx; + const std::string ep_val = ep; + const std::string var_name_val = var_name; + const framework::Scope* p_scope = &scope; + const auto ch = GetChannel(ep_val); + + framework::Async([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, this] { + auto* var = p_scope->FindVar(var_name_val); + sendrecv::VariableMessage req; + SerializeToMessage(var_name_val, var, *p_ctx, &req); + + // varhandle + VarHandle var_h; + var_h.ep = ep_val; + var_h.scope = p_scope; + var_h.name = var_name_val; + var_h.ctx = p_ctx; + + // stub context + SendProcessor* s = new SendProcessor(ch); + s->Prepare(var_h, time_out); + s->response_call_back_ = NULL; + + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + }); req_count_++; @@ -50,8 +58,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, void ProcGetResponse(const VarHandle& var_h, const sendrecv::VariableMessage& ret_msg) { auto* outvar = var_h.scope->FindVar(var_h.name); - - std::istringstream iss(ret_msg.serialized()); DeserializeFromMessage(ret_msg, *var_h.ctx, outvar); } @@ -60,24 +66,31 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, const framework::Scope& scope, const std::string& var_name, int64_t time_out) { - sendrecv::VariableMessage req; - req.set_varname(var_name); - - // varhandle - VarHandle var_h; - var_h.ep = ep; - var_h.scope = &scope; - var_h.name = var_name; - var_h.ctx = &ctx; - - // stub context - auto ch = GetChannel(ep); - GetProcessor* s = new GetProcessor(ch); - s->Prepare(var_h, time_out); - s->response_call_back_ = ProcGetResponse; - - auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, (void*)s); + const platform::DeviceContext* p_ctx = &ctx; + const std::string ep_val = ep; + const std::string var_name_val = var_name; + const framework::Scope* p_scope = &scope; + const auto ch = GetChannel(ep_val); + + framework::Async([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, this] { + sendrecv::VariableMessage req; + req.set_varname(var_name_val); + + // varhandle + VarHandle var_h; + var_h.ep = ep_val; + var_h.scope = p_scope; + var_h.name = var_name_val; + var_h.ctx = p_ctx; + + // stub context + GetProcessor* s = new GetProcessor(ch); + s->Prepare(var_h, time_out); + s->response_call_back_ = ProcGetResponse; + + auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + }); req_count_++; @@ -85,19 +98,31 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, } bool RPCClient::Wait() { - bool ok = true; + if (req_count_ <= 0) { + return true; + } - while (true) { - if (req_count_ <= 0) { - break; - } + std::vector a(req_count_); + std::vector> waits(req_count_); - if (!Proceed()) { + for (int i = 0; i < req_count_; i++) { + waits[i] = framework::Async([i, &a, this] { a[i] = Proceed(); }); + } + + for (int i = 0; i < req_count_; i++) { + waits[i].wait(); + } + + int last_req_count = req_count_; + req_count_ = 0; + + for (int i = 0; i < last_req_count; i++) { + if (!a[i]) { return false; } } - return ok; + return true; } bool RPCClient::Proceed() { @@ -124,7 +149,6 @@ bool RPCClient::Proceed() { c->Process(); delete c; - req_count_--; return true; }