Unverified commit 078223b3 authored by gongweibao, committed by GitHub

Add rpc timeline. (#13900)

Add rpc timeline
Parent e3964e5a
File mode changed from 100644 to 100755
@@ -20,7 +20,7 @@ if(WITH_GRPC)
DEPS grpc++_unsecure grpc_unsecure gpr cares zlib protobuf sendrecvop_grpc scope profiler math_function SERIAL)
cc_test(rpc_server_test SRCS rpc_server_test.cc
DEPS sendrecvop_grpc grpc++_unsecure grpc_unsecure gpr cares zlib protobuf executor proto_desc lookup_sparse_table_op SERIAL)
cc_test(varhandle_test SRCS varhandle_test.cc)
cc_test(varhandle_test SRCS varhandle_test.cc DEPS profiler)
return()
endif()
......
@@ -73,10 +73,11 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Send", var_name_val, p_ctx, p_scope));
const std::string method = "SendRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, this] {
framework::AsyncIO([var_name_val, p_scope, p_ctx, s, method, h, this] {
auto* var = p_scope->FindVar(var_name_val);
::grpc::ByteBuffer req;
@@ -87,10 +88,16 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
// stub context
s->response_call_back_ = nullptr;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/SendVariable", req, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
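AsyncSendVar above, and each of the client calls that follow, use the same instrumentation pattern: the RPC gets a method name ("SendRPC", "GetRPC", ...), a scoped platform::RecordEvent is opened inside the framework::AsyncIO lambda, and when profiling is enabled h->Wait() holds that scope open until the reply arrives, so the recorded span covers the whole round trip rather than just the hand-off to gRPC. A minimal, self-contained sketch of this scoped-event-plus-conditional-wait pattern, using a toy ScopedEvent and a std::shared_future in place of platform::RecordEvent and VarHandlePtr (stand-ins for illustration, not the Paddle types):

#include <chrono>
#include <future>
#include <iostream>
#include <string>
#include <thread>

// Toy stand-in for platform::RecordEvent: measures the lifetime of a scope.
struct ScopedEvent {
  explicit ScopedEvent(std::string name)
      : name_(std::move(name)), start_(std::chrono::steady_clock::now()) {}
  ~ScopedEvent() {
    auto us = std::chrono::duration_cast<std::chrono::microseconds>(
                  std::chrono::steady_clock::now() - start_)
                  .count();
    std::cout << name_ << " took " << us << " us\n";
  }
  std::string name_;
  std::chrono::steady_clock::time_point start_;
};

bool profiling_enabled = true;  // stands in for platform::IsProfileEnabled()

int main() {
  const std::string method = "SendRPC";
  std::promise<void> rpc_done;
  std::shared_future<void> handle = rpc_done.get_future().share();

  // Completion-queue thread: pretends the server replies after a short delay.
  std::thread completion_queue([&rpc_done] {
    std::this_thread::sleep_for(std::chrono::milliseconds(5));
    rpc_done.set_value();
  });

  {
    ScopedEvent event(method);  // platform::RecordEvent(method, p_ctx)
    // ... serialize the variable, PrepareUnaryCall/StartCall/Finish ...
    if (profiling_enabled) {
      handle.wait();  // h->Wait(): hold the event open until the reply lands
    }
  }

  completion_queue.join();
  return 0;
}

With profiling disabled, the event closes as soon as the call has been issued, which preserves the original fire-and-forget behaviour of the client.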
@@ -122,10 +129,11 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "Get", var_name_val, p_ctx, p_scope));
const std::string method = "GetRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, s, this] {
framework::AsyncIO([var_name_val, s, method, p_ctx, h, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
@@ -137,10 +145,16 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/GetVariable", buf, &cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
@@ -161,12 +175,14 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "Prefetch", out_var_name_val, p_ctx, p_scope));
const std::string method = "PrefetchRPC";
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
s, this] {
s, method, h, this] {
auto* var = p_scope->FindVar(in_var_name_val);
::grpc::ByteBuffer req;
@@ -177,11 +193,17 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
// stub context
s->response_call_back_ = ProcGetResponse;
platform::RecordEvent record_event(method, p_ctx);
auto call = s->stub_g_.PrepareUnaryCall(
s->context_.get(), "/sendrecv.SendRecvService/PrefetchVariable", req,
&cq_);
call->StartCall();
call->Finish(&s->reply_, &s->status_, static_cast<void*>(s));
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
});
req_count_++;
@@ -193,15 +215,24 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "BatchBarrier", BATCH_BARRIER_MESSAGE,
nullptr, nullptr));
const std::string method = "BatchBarrierRPC";
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
@@ -209,15 +240,24 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "FetchBarrier", FETCH_BARRIER_MESSAGE,
nullptr, nullptr));
const std::string method = "FetchBarrierRPC";
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncGetVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
@@ -226,15 +266,23 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
VarHandlePtr h(
new VarHandle(ep, "SendComplete", COMPLETE_MESSAGE, nullptr, nullptr));
const std::string method = "SendCompleteRPC";
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(COMPLETE_MESSAGE);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
@@ -244,17 +292,27 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
VarHandlePtr h(new VarHandle(ep, "CheckPointNotify", CHECKPOINT_SAVE_MESSAGE,
nullptr, nullptr));
const std::string method = "CheckPointNotifyRPC";
VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
platform::RecordEvent record_event(method, nullptr);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
if (UNLIKELY(platform::IsProfileEnabled())) {
h->Wait();
}
return h;
}
@@ -273,6 +331,7 @@ void GRPCClient::Proceed() {
BaseProcessor* c = static_cast<BaseProcessor*>(tag);
GPR_ASSERT(ok);
PADDLE_ENFORCE(c);
if (c->status_.ok()) {
VLOG(3) << c->GetVarHandlePtr()->String() << " process";
c->Process();
......
@@ -36,6 +36,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
const platform::DeviceContext& ctx,
::grpc::ByteBuffer* msg,
const std::string& out_name) {
platform::RecordEvent record_event("serial", &ctx);
// Default DestroyCallback does nothing, When using GPU
// the CPU buffer need to be freed.
DestroyCallback destroy_callback = [](void* backing) {};
@@ -147,6 +148,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
const platform::DeviceContext& ctx,
const framework::Scope* scope,
framework::Variable** var) {
platform::RecordEvent record_event("deserial", &ctx);
operators::distributed::GRPCVariableResponse resp(scope, &ctx);
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
*var = resp.GetVar();
......
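The RecordEvent("serial") and RecordEvent("deserial") scopes above bracket request serialization and response parsing. Since SerializeToByteBuffer runs inside the same AsyncIO lambda as the RPC, its "serial" event nests under the enclosing *RPC event on the timeline, while "deserial" is recorded wherever the byte buffer is parsed back into a variable. A rough, hedged sketch of the nesting for one send (RecordEvent as declared in the profiler header; the include path is assumed, the device context argument and the gRPC calls are placeholders, and SketchOfOneSend is not a function from this commit):

#include "paddle/fluid/platform/profiler.h"

// Hedged sketch of the event nesting produced by one profiled send.
void SketchOfOneSend(const paddle::platform::DeviceContext* p_ctx) {
  using paddle::platform::RecordEvent;
  RecordEvent rpc("SendRPC", p_ctx);   // opened inside the AsyncIO lambda
  {
    RecordEvent ser("serial", p_ctx);  // opened in SerializeToByteBuffer
    // ... copy the variable into a ::grpc::ByteBuffer ...
  }
  // ... PrepareUnaryCall / StartCall / Finish, then h->Wait() under profiling ...
}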
@@ -66,7 +66,7 @@ static void ParallelExecuteBlocks(
<< "pointer: " << prepared[run_block].get();
executor->RunPreparedContext(prepared[run_block].get(), scope);
} catch (const std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
LOG(FATAL) << "run sub program:" << idx << " error " << e.what();
}
}));
}
......
@@ -71,6 +71,7 @@ void PopEvent(const std::string& name, const DeviceContext* dev_ctx);
#if !defined(_WIN32)
struct RecordEvent {
// dev_ctx can be set to nullptr if device is cpu.
RecordEvent(const std::string& name, const DeviceContext* dev_ctx);
~RecordEvent();
......
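RecordEvent is the RAII helper declared above: its constructor pushes a named event, optionally tied to a device context, and its destructor pops it, so an event's duration is simply the lifetime of the enclosing scope. Events are only collected while the profiler is enabled, which is what the IsProfileEnabled() checks in the client gate on. A hedged usage sketch, assuming the EnableProfiler/DisableProfiler/ProfilerState/EventSortingKey API from this codebase's profiler header of this era (names and the include path may differ in other versions):

#include "paddle/fluid/platform/profiler.h"

// Hedged sketch: enable profiling, time a region with RecordEvent, then
// dump the collected events (including the RPC ones above) to a file.
void ProfileOneRegion() {
  namespace plat = paddle::platform;
  plat::EnableProfiler(plat::ProfilerState::kCPU);  // IsProfileEnabled() becomes true
  {
    // dev_ctx may be nullptr for CPU-only regions, as in the barrier RPCs.
    plat::RecordEvent event("SendRPC", nullptr);
    // ... issue the RPC and h->Wait() here ...
  }
  plat::DisableProfiler(plat::EventSortingKey::kTotal, "/tmp/rpc_profile");
}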
@@ -91,6 +91,8 @@ class TestDistSimnetBow2x2SparseAsync(TestDistBase):
need_envs=need_envs)
# FIXME(tangwei): Learningrate variable is not created on pserver.
"""
class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
def _setup_config(self):
self._sync_mode = True
@@ -105,7 +107,7 @@ class TestDistSimnetBow2x2LookupTableSync(TestDistBase):
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=False,
check_error_log=True,
need_envs=need_envs)
@@ -143,7 +145,7 @@ class TestDistSimnetBow2x2LookupTableNotContainLRSync(TestDistBase):
delta=1e-5,
check_error_log=False,
need_envs=need_envs)
"""
if __name__ == "__main__":
unittest.main()