未验证 提交 078223b3 编写于 作者: G gongweibao 提交者: GitHub

Add rpc timeline. (#13900)

Add rpc timeline
上级 e3964e5a
文件模式从 100644 更改为 100755
...@@ -20,7 +20,7 @@ if(WITH_GRPC) ...@@ -20,7 +20,7 @@ if(WITH_GRPC)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL) DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL) DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc) cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
return() return()
endif() endif()
......
...@@ -73,10 +73,11 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -73,10 +73,11 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch); SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope)); const std::string method = "SendRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] { framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val); auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
...@@ -87,10 +88,16 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep, ...@@ -87,10 +88,16 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
// stub context // stub context
s->response_call_back_ = nullptr; s->response_call_back_ = nullptr;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_); s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
}); });
req_count_++; req_count_++;
...@@ -122,10 +129,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -122,10 +129,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch); GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope)); const std::string method = "GetRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, s, this] { framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
// prepare input // prepare input
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(var_name_val); req.set_varname(var_name_val);
...@@ -137,10 +145,16 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep, ...@@ -137,10 +145,16 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// stub context // stub context
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_); s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
}); });
req_count_++; req_count_++;
...@@ -161,12 +175,14 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -161,12 +175,14 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val); const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch); GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope)); const std::string method = "PrefetchRPC";
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out); s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, this] { s, method, h, this] {
auto* var = p_scope->FindVar(in_var_name_val); auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req; ::grpc::ByteBuffer req;
...@@ -177,11 +193,17 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep, ...@@ -177,11 +193,17 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
// stub context // stub context
s->response_call_back_ = ProcGetResponse; s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall( auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req, s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_); &cq_);
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
}); });
req_count_++; req_count_++;
...@@ -193,15 +215,24 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep, ...@@ -193,15 +215,24 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE, const std::string method = "BatchBarrierRPC";
nullptr, nullptr)); VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE); req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h; return h;
} }
...@@ -209,15 +240,24 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep, ...@@ -209,15 +240,24 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) { int64_t time_out) {
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE, const std::string method = "FetchBarrierRPC";
nullptr, nullptr)); VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE); req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h; return h;
} }
...@@ -226,15 +266,23 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep, ...@@ -226,15 +266,23 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
VarHandlePtr h( const std::string method = "SendCompleteRPC";
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr)); VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE); req.set_varname(COMPLETE_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h; return h;
} }
...@@ -244,17 +292,27 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep, ...@@ -244,17 +292,27 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const auto ch = GetChannel(ep); const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch); CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
nullptr, nullptr)); const std::string method = "CheckPointNotifyRPC";
VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out); s->Prepare(h, time_out);
sendrecv::VariableMessage req; sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE); req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir); req.set_out_varname(dir);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_); auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++; req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h; return h;
} }
...@@ -273,6 +331,7 @@ void GRPCClient::Proceed() { ...@@ -273,6 +331,7 @@ void GRPCClient::Proceed() {
BaseProcessor* c = static_cast<BaseProcessor*>(tag); BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok); GPR_ASSERT(ok);
PADDLE_ENFORCE(c); PADDLE_ENFORCE(c);
if (c->status_.ok()) { if (c->status_.ok()) {
VLOG(3) << c->GetVarHandlePtr()->String() << " process"; VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process(); c->Process();
......
...@@ -36,6 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -36,6 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg, ::grpc::ByteBuffer* msg,
const std::string& out_name) { const std::string& out_name) {
platform::RecordEvent record_event("serial", &ctx);
// Default DestroyCallback does nothing, When using GPU // Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed. // the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {}; DestroyCallback destroy_callback = [](void* backing) {};
...@@ -147,6 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg, ...@@ -147,6 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope* scope, const framework::Scope* scope,
framework::Variable** var) { framework::Variable** var) {
platform::RecordEvent record_event("deserial", &ctx);
operators::distributed::GRPCVariableResponse resp(scope, &ctx); operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!"); PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar(); *var = resp.GetVar();
......
...@@ -66,7 +66,7 @@ static void ParallelExecuteBlocks( ...@@ -66,7 +66,7 @@ static void ParallelExecuteBlocks(
<< "pointer: " << prepared[run_block].get(); << "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope); executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) { } catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what(); LOG(FATAL) << "run sub program:" << idx << " error " << e.what();
} }
})); }));
} }
......
...@@ -71,6 +71,7 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx); ...@@ -71,6 +71,7 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx);
#if !defined(_WIN32) #if !defined(_WIN32)
struct RecordEvent { struct RecordEvent {
// dev_ctx can be set to nullptr if device is cpu.
RecordEvent(const std::string& name, const DeviceContext* dev_ctx); RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
~RecordEvent(); ~RecordEvent();
......
...@@ -91,6 +91,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase): ...@@ -91,6 +91,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
need_envs=need_envs) need_envs=need_envs)
# FIXME(tangwei): Learningrate variable is not created on pserver.
"""
class TestDistSimnetBow2x2LookupTableSync(TestDistBase): class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
...@@ -105,7 +107,7 @@ class TestDistSimnetBow2x2LookupTableSync(TestDistBase): ...@@ -105,7 +107,7 @@ class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
self.check_with_place( self.check_with_place(
"dist_simnet_bow.py", "dist_simnet_bow.py",
delta=1e-5, delta=1e-5,
check_error_log=False, check_error_log=True,
need_envs=need_envs) need_envs=need_envs)
...@@ -143,7 +145,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase): ...@@ -143,7 +145,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase):
delta=1e-5, delta=1e-5,
check_error_log=False, check_error_log=False,
need_envs=need_envs) need_envs=need_envs)
"""
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册