未验证 提交 03d4665f 编写于 作者: 1 123malin 提交者: GitHub

prefetch optimize (#29095)

* test=develop, optimize async prefetch
上级 7c61ba3a
......@@ -162,6 +162,18 @@ void AsyncCommunicator::SendByCommunicator() {
auto after_send = GetCurrentUS();
VLOG(3) << "send " << var_name << " use time "
<< after_send - after_merge;
if (var_name.rfind("@GRAD") != var_name.size() - 5) return;
auto recv_param = var_name.substr(0, var_name.size() - 5);
if (recv_varname_to_ctx_.find(recv_param) == recv_varname_to_ctx_.end())
return;
auto recv_functor = distributed::ParameterRecv<float>();
recv_functor(recv_varname_to_ctx_.at(recv_param), *recv_scope_);
auto after_recv = GetCurrentUS();
VLOG(3) << "recv " << recv_param << " use time "
<< after_recv - after_send;
};
task_futures.emplace_back(send_threadpool_->enqueue(std::move(send_task)));
}
......
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/platform/port.h"
#include "paddle/fluid/platform/profiler.h"
DEFINE_int32(rpc_client_threads, 2, "");
DECLARE_bool(rpc_disable_reuse_port);
namespace paddle {
......@@ -32,10 +33,11 @@ namespace distributed {
// Starts the background thread(s) that run GRPCClient::Proceed.
//
// NOTE(review): this text appears to come from a rendered diff in which
// removed and added lines are shown together without +/- markers. The
// single `client_thread_` check/reset and the `client_threads_` loop that
// follows are presumably the pre-change and post-change versions of the same
// step -- confirm against the real file before treating both as live code.
void GRPCClient::InitImpl() {
// start the client process thread
// TODO(wuyi): can make this in a threadpool
// Enforce that client_thread_ is still null before a thread is created;
// otherwise a PreconditionNotMet error is raised.
PADDLE_ENFORCE_EQ(client_thread_ == nullptr, true,
platform::errors::PreconditionNotMet(
"please not re init proceed thread"));
client_thread_.reset(new std::thread(std::bind(&GRPCClient::Proceed, this)));
// One slot per thread requested through the rpc_client_threads flag; each
// slot takes ownership of a std::thread bound to Proceed on this client.
// NOTE(review): the flag is a signed int32; a negative value would convert
// to a very large size in resize() -- confirm it is validated elsewhere.
client_threads_.resize(FLAGS_rpc_client_threads);
for (int i = 0; i < FLAGS_rpc_client_threads; i++) {
client_threads_[i].reset(
new std::thread(std::bind(&GRPCClient::Proceed, this)));
}
}
void GRPCClient::SendComplete() {
......@@ -62,7 +64,8 @@ GRPCClient::~GRPCClient() {
}
channels_.clear();
}
client_thread_->join();
for (size_t i = 0; i < client_threads_.size(); i++)
client_threads_[i]->join();
}
VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
......@@ -84,7 +87,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
......@@ -206,8 +209,8 @@ VarHandlePtr GRPCClient::_AsyncGetVar(
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, out_varname_val, table_name_val, s,
method, p_ctx, h, rpc_path, this] {
framework::Async([var_name_val, out_varname_val, table_name_val, s, method,
p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
......@@ -273,31 +276,29 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, kPrefetchTimeout);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope,
p_ctx, s, method, h, table_name_val, this] {
auto* var = p_scope->FindVar(in_var_name_val);
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req,
out_var_name_val, 0, table_name_val);
::grpc::ByteBuffer req;
SerializeToByteBuffer(in_var_name_val, var, *p_ctx, &req, out_var_name_val,
0, table_name_val);
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
VLOG(3) << s->GetVarHandlePtr()->String() << " begin";
// stub context
s->response_call_back_ = ProcGetResponse;
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordRPCEvent record_event(method);
platform::RecordRPCEvent record_event(method);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
if (FLAGS_rpc_retry_times > 0 && retry_times_ < FLAGS_rpc_retry_times) {
......@@ -467,7 +468,7 @@ VarHandlePtr GRPCClient::AsyncDistributeNotify(
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
framework::Async([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
......@@ -523,8 +524,8 @@ VarHandlePtr GRPCClient::AsyncSendAndRecv(const std::string& ep,
s->Prepare(h, time_out);
s->RecvPrepare(h_recv);
framework::AsyncIO([send_var_name_val, recv_var_name_val, table_name_val,
p_scope, p_ctx, s, method, h, this] {
framework::Async([send_var_name_val, recv_var_name_val, table_name_val,
p_scope, p_ctx, s, method, h, this] {
auto* send_var = p_scope->FindVar(send_var_name_val);
send_var->GetMutable<framework::LoDTensor>()->set_lod({});
::grpc::ByteBuffer buf;
......
......@@ -297,7 +297,7 @@ class GRPCClient : public RPCClient {
private:
grpc::CompletionQueue cq_;
std::unordered_map<std::string, std::shared_ptr<grpc::Channel>> channels_;
std::unique_ptr<std::thread> client_thread_{nullptr};
std::vector<std::unique_ptr<std::thread>> client_threads_;
// mutex for Wait client sync
std::mutex sync_mutex_;
......
......@@ -85,7 +85,7 @@ class RPCServer {
// class, and auto generate a condition id for this call
// to be used for the barrier.
void RegisterRPC(const std::string& rpc_name, RequestHandler* handler,
int thread_num = 5);
int thread_num = 1);
int GetThreadNum(const std::string& rpc_name) {
return rpc_thread_num_[rpc_name];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册