From 03d4665f44aafdb5dc5861e901277837ab7a89d5 Mon Sep 17 00:00:00 2001
From: 123malin
Date: Mon, 30 Nov 2020 14:33:16 +0800
Subject: [PATCH] prefetch optimize (#29095)

* test=develop, optimize async prefetch
---
 .../operators/distributed/communicator.cc     | 12 ++++
 .../operators/distributed/grpc/grpc_client.cc | 61 ++++++++++---------
 .../operators/distributed/grpc/grpc_client.h  |  2 +-
 .../fluid/operators/distributed/rpc_server.h  |  2 +-
 4 files changed, 45 insertions(+), 32 deletions(-)

diff --git a/paddle/fluid/operators/distributed/communicator.cc b/paddle/fluid/operators/distributed/communicator.cc
index 07427bb69d9..54dd4208fdb 100644
--- a/paddle/fluid/operators/distributed/communicator.cc
+++ b/paddle/fluid/operators/distributed/communicator.cc
@@ -162,6 +162,18 @@ void AsyncCommunicator::SendByCommunicator() {
       auto after_send = GetCurrentUS();
       VLOG(3) << "send " << var_name << " use time "
               << after_send - after_merge;
+
+      if (var_name.rfind("@GRAD") != var_name.size() - 5) return;
+
+      auto recv_param = var_name.substr(0, var_name.size() - 5);
+      if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end())
+        return;
+
+      auto recv_functor = distributed::ParameterRecv<float>();
+      recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_);
+      auto after_recv = GetCurrentUS();
+      VLOG(3) << "recv " << recv_param << " use time "
+              << after_recv - after_send;
     };
     task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
   }
diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.cc b/paddle/fluid/operators/distributed/grpc/grpc_client.cc
index 0320ef6595d..97a9c14e4f1 100644
--- a/paddle/fluid/operators/distributed/grpc/grpc_client.cc
+++ b/paddle/fluid/operators/distributed/grpc/grpc_client.cc
@@ -23,6 +23,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/port.h"
 #include "paddle/fluid/platform/profiler.h"
 
+DEFINE_int32(rpc_client_threads, 2, "");
 DECLARE_bool(rpc_disable_reuse_port);
 
 namespace paddle {
@@ -32,10 +33,11 @@ namespace distributed {
 void GRPCClient::InitImpl() {
   // start the client process thread
   // TODO(wuyi): can make this in a threadpool
-  PADDLE_ENFORCE_EQ(client_thread_ == nullptr, true,
-                    platform::errors::PreconditionNotMet(
-                        "please not re init proceed thread"));
-  client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
+  client_threads_.resize(FLAGS_rpc_client_threads);
+  for (int i = 0; i < FLAGS_rpc_client_threads; i++) {
+    client_threads_[i].reset(
+        new std::thread(std::bind(&GRPCClient::Proceed, this)));
+  }
 }
 
 void GRPCClient::SendComplete() {
@@ -62,7 +64,8 @@ GRPCClient::~GRPCClient() {
     }
     channels_.clear();
   }
-  client_thread_->join();
+  for (size_t i = 0; i < client_threads_.size(); i++)
+    client_threads_[i]->join();
 }
 
 VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
@@ -84,7 +87,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
+  framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
     auto* var = p_scope->FindVar(var_name_val);
 
     ::grpc::ByteBuffer req;
@@ -206,8 +209,8 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
   VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s,
-                      method, p_ctx, h, rpc_path, this] {
+  framework::Async([var_name_val, out_varname_val, table_name_val, s, method,
+                    p_ctx, h, rpc_path, this] {
     // prepare input
     sendrecv::VariableMessage req;
     req.set_varname(var_name_val);
@@ -273,31 +276,29 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
   VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
   s->Prepare(h, kPrefetchTimeout);
 
-  framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
-                      p_ctx, s, method, h, table_name_val, this] {
-    auto* var = p_scope->FindVar(in_var_name_val);
+  auto* var = p_scope->FindVar(in_var_name_val);
 
-    ::grpc::ByteBuffer req;
-    SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req,
-                          out_var_name_val, 0, table_name_val);
+  ::grpc::ByteBuffer req;
+  SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
+                        0, table_name_val);
 
-    VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
+  VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
 
-    // stub context
-    s->response_call_back_ = ProcGetResponse;
+  // stub context
+  s->response_call_back_ = ProcGetResponse;
 
-    platform::RecordRPCEvent record_event(method);
+  platform::RecordRPCEvent record_event(method);
 
-    auto call = s->stub_g_.PrepareUnaryCall(
-        s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
-        &cq_);
-    call->StartCall();
-    call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
+  auto call = s->stub_g_.PrepareUnaryCall(
+      s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
+      &cq_);
+  call->StartCall();
+  call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
+
+  if (UNLIKELY(platform::IsProfileEnabled())) {
+    h->Wait();
+  }
 
-    if (UNLIKELY(platform::IsProfileEnabled())) {
-      h->Wait();
-    }
-  });
   req_count_++;
 
   if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
@@ -467,7 +468,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
   VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
   s->Prepare(h, time_out);
 
-  framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
+  framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
     auto* var = p_scope->FindVar(var_name_val);
 
     ::grpc::ByteBuffer req;
@@ -523,8 +524,8 @@ VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
   s->Prepare(h, time_out);
   s->RecvPrepare(h_recv);
 
-  framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val,
-                      p_scope, p_ctx, s, method, h, this] {
+  framework::Async([send_var_name_val, recv_var_name_val, table_name_val,
+                    p_scope, p_ctx, s, method, h, this] {
     auto* send_var = p_scope->FindVar(send_var_name_val);
     send_var->GetMutable<framework::LoDTensor>()->set_lod({});
     ::grpc::ByteBuffer buf;
diff --git a/paddle/fluid/operators/distributed/grpc/grpc_client.h b/paddle/fluid/operators/distributed/grpc/grpc_client.h
index 7b269f4d80c..5885f944b60 100644
--- a/paddle/fluid/operators/distributed/grpc/grpc_client.h
+++ b/paddle/fluid/operators/distributed/grpc/grpc_client.h
@@ -297,7 +297,7 @@ class GRPCClient : public RPCClient {
  private:
   grpc::CompletionQueue cq_;
   std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
-  std::unique_ptr<std::thread> client_thread_{nullptr};
+  std::vector<std::unique_ptr<std::thread>> client_threads_;
 
   // mutex for Wait client sync
   std::mutex sync_mutex_;
diff --git a/paddle/fluid/operators/distributed/rpc_server.h b/paddle/fluid/operators/distributed/rpc_server.h
index f83144f6268..2120260515e 100644
--- a/paddle/fluid/operators/distributed/rpc_server.h
+++ b/paddle/fluid/operators/distributed/rpc_server.h
@@ -85,7 +85,7 @@ class RPCServer {
   // class, and auto generate a condition id for this call
   // to be used for the barrier.
   void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
-                   int thread_num = 5);
+                   int thread_num = 1);
 
   int GetThreadNum(const std::string& rpc_name) {
     return rpc_thread_num_[rpc_name];
--
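Note: the new branch in SendByCommunicator keys the follow-up recv off the gradient naming
convention: a merged gradient named "<param>@GRAD" triggers an immediate ParameterRecv for
"<param>", but only if that parameter has a registered recv context. The sketch below is a
standalone illustration of that name mapping, not code from the patch; RecvParamFor and the
sample variable names are hypothetical.

// Standalone sketch (not part of the patch): mirrors the "@GRAD" suffix test
// used by the new SendByCommunicator branch, mapping "<param>@GRAD" to "<param>".
#include <iostream>
#include <string>

// Returns the parameter name for a gradient variable, or "" if `var_name`
// does not end with "@GRAD". RecvParamFor is a hypothetical helper.
std::string RecvParamFor(const std::string& var_name) {
  const std::string suffix = "@GRAD";
  if (var_name.size() < suffix.size() ||
      var_name.rfind(suffix) != var_name.size() - suffix.size()) {
    return "";
  }
  return var_name.substr(0, var_name.size() - suffix.size());
}

int main() {
  // "fc_0.w_0" is an illustrative parameter name.
  std::cout << RecvParamFor("fc_0.w_0@GRAD") << "\n";  // prints "fc_0.w_0"
  std::cout << RecvParamFor("fc_0.w_0") << "\n";       // prints an empty line
  return 0;
}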