未验证 提交 fe8f28c9 编写于 作者: G gongweibao 提交者: GitHub

Add GetVariableNoBarrier on brpc. (#15488)

上级 981fc2bd
......@@ -20,7 +20,7 @@ if(WITH_GRPC)
collective_client.cc collective_server.cc
${GRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows_functor memory)
DEPS lod_tensor selected_rows_functor memory scope ${GRPC_DEPS})
set_source_files_properties(grpc_serde_test.cc rpc_server_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(RPC_DEPS sendrecvop_rpc ${GRPC_DEPS})
......@@ -32,15 +32,17 @@ else()
set(BRPC_SRCS brpc/brpc_client.cc brpc/brpc_server.cc brpc/brpc_sendrecvop_utils.cc brpc/brpc_variable_response.cc brpc/brpc_rdma_pool.cc)
set_source_files_properties(${BRPC_SRCS} parameter_prefetch.cc rpc_server_test.cc brpc/brpc_serde_test.cc collective_server.cc collective_server_test.cc collective_client.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set(BRPC_DEPS brpc ssl crypto protobuf leveldb snappystream snappy zlib)
brpc_library(sendrecvop_rpc SRCS sendrecvop_utils.cc
request_handler_impl.cc rpc_client.cc rpc_server.cc
variable_response.cc
collective_client.cc collective_server.cc
${BRPC_SRCS}
PROTO send_recv.proto
DEPS lod_tensor selected_rows memory)
DEPS lod_tensor selected_rows memory scope ${BRPC_DEPS})
set(RPC_DEPS sendrecvop_rpc brpc ssl crypto protobuf leveldb snappystream snappy zlib)
set(RPC_DEPS sendrecvop_rpc ${BRPC_DEPS})
cc_test(brpc_serde_test SRCS brpc/brpc_serde_test.cc
DEPS ${RPC_DEPS} gflags glog executor proto_desc lookup_sparse_table_op SERIAL)
endif()
......
......@@ -62,7 +62,7 @@ VarHandlePtr BRPCClient::AsyncSendVar(const std::string& ep,
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "SendRPC";
const std::string method = kSendRPC;
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
framework::AsyncIO([=] {
......@@ -156,15 +156,18 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& method_name,
int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const std::string out_varname_val = out_var_name;
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "GetRPC";
VarHandlePtr var_h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
const std::string method = kGetRPC;
VarHandlePtr var_h(
new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
framework::AsyncIO([=] {
auto ch_ctx = ch_ptr->Pop();
......@@ -175,6 +178,7 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
google::protobuf::Closure* done = brpc::NewCallback(
......@@ -182,8 +186,10 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
platform::RecordRPCEvent record_event(method, p_ctx);
if (method_name == "GetMonomerVariable") {
if (method_name == kGetMonomerRPC) {
ch_ctx->stub->GetMonomerVariable(cntl, &req, response, done);
} else if (method_name == kGetNoBarrierRPC) {
ch_ctx->stub->GetVariableNoBarrier(cntl, &req, response, done);
} else {
ch_ctx->stub->GetVariable(cntl, &req, response, done);
}
......@@ -198,25 +204,39 @@ VarHandlePtr BRPCClient::_AsyncGetVar(const std::string& ep,
return var_h;
}
VarHandlePtr BRPCClient::AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_var_name, int64_t time_out) {
std::string var_name_no_barrier =
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
return _AsyncGetVar(ep, ctx, scope, var_name_no_barrier, out_var_name,
kGetNoBarrierRPC, time_out);
}
VarHandlePtr BRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, "GetMonomerVariable", time_out);
return _AsyncGetVar(ep, ctx, scope, var_name, var_name, kGetMonomerRPC,
time_out);
}
VarHandlePtr BRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
const std::string& var_name,
int64_t time_out) {
return AsyncSendMessage(ep, "GetMonomerBarrier", var_name, time_out);
return AsyncSendMessage(ep, kSendMonomerFetchBarrierRPC, var_name, time_out);
}
VarHandlePtr BRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name, "GetVariable", time_out);
return _AsyncGetVar(ep, ctx, scope, var_name, out_var_name, kGetRPC,
time_out);
}
VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
......@@ -234,7 +254,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch_ptr = GetChannel(ep_val);
const std::string method = "PrefetchRPC";
const std::string method = kPrefetchRPC;
VarHandlePtr var_h(
new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
......@@ -270,7 +290,7 @@ VarHandlePtr BRPCClient::AsyncPrefetchVar(const std::string& ep,
VarHandlePtr BRPCClient::AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, "BatchBarrierRPC", BATCH_BARRIER_MESSAGE,
return AsyncSendMessage(ep, kBatchBarrierRPC, BATCH_BARRIER_MESSAGE,
time_out);
}
......@@ -286,7 +306,7 @@ VarHandlePtr BRPCClient::AsyncSendFetchBarrier(const std::string& ep,
sendrecv::VariableMessage req;
req.set_varname(FETCH_BARRIER_MESSAGE);
const std::string method = "FetchBarrierRPC";
const std::string method = kFetchBarrierRPC;
// var handle
VarHandlePtr var_h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
......@@ -367,7 +387,7 @@ ChannelQueuePtr BRPCClient::GetChannel(const std::string& ep) {
VarHandlePtr BRPCClient::AsyncSendComplete(const std::string& ep,
int64_t time_out) {
return AsyncSendMessage(ep, "SendCompleteRPC", COMPLETE_MESSAGE, time_out);
return AsyncSendMessage(ep, kSendCompleteRPC, COMPLETE_MESSAGE, time_out);
}
void BRPCClient::SendComplete() {
......@@ -394,9 +414,9 @@ VarHandlePtr BRPCClient::AsyncSendVarMessage(
google::protobuf::Closure* done = brpc::NewCallback(
&HandleSendResponse, cntl, response, var_h, ch_ptr, ch_ctx, this);
if (method_name == "CheckPointNotifyRPC") {
if (method_name == kCheckPointNotifyRPC) {
ch_ctx->stub->CheckpointNotify(cntl, &req, response, done);
} else if (method_name == "GetMonomerBarrier") {
} else if (method_name == kSendMonomerFetchBarrierRPC) {
ch_ctx->stub->GetMonomerBarrier(cntl, &req, response, done);
} else {
ch_ctx->stub->SendVariable(cntl, &req, response, done);
......
......@@ -65,6 +65,7 @@ class BRPCClient : public RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerBarrier(
......@@ -76,6 +77,13 @@ class BRPCClient : public RPCClient {
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline);
VarHandlePtr AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
......@@ -103,6 +111,7 @@ class BRPCClient : public RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_var_name,
const std::string& method_name,
int64_t time_out = FLAGS_rpc_deadline);
......
......@@ -45,6 +45,13 @@ class BRPCServiceImpl : public SendRecvService {
rpc_server_->GetThreadNum(distributed::kRequestGet)));
}
it = rpc_call_map.find(distributed::kRequestGetNoBarrier);
if (it != rpc_call_map.end()) {
request_getnobarrier_h_ = it->second;
getnobarrier_threads_.reset(new paddle::framework::ThreadPool(
rpc_server_->GetThreadNum(distributed::kRequestGetNoBarrier)));
}
it = rpc_call_map.find(distributed::kRequestPrefetch);
if (it != rpc_call_map.end()) {
request_prefetch_h_ = it->second;
......@@ -112,6 +119,14 @@ class BRPCServiceImpl : public SendRecvService {
[=] { _GetVariable(cntl_butil, request, response, done); });
}
void GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) override {
getnobarrier_threads_->Run(
[=] { _GetVariableNoBarrier(cntl_butil, request, response, done); });
}
void _GetVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request, VariableMessage* response,
google::protobuf::Closure* done) {
......@@ -122,23 +137,59 @@ class BRPCServiceImpl : public SendRecvService {
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
std::string out_varname = request->out_varname();
VLOG(3) << "RequestGet varname:" << varname
<< ", out_varname:" << out_varname
<< ", trainer_id:" << request->trainer_id()
<< ", from:" << cntl->remote_side();
auto scope = request_get_h_->scope();
auto invar = scope->FindVar(varname);
paddle::framework::Variable* invar = nullptr;
int trainer_id = request->trainer_id();
paddle::framework::Variable* outvar = nullptr;
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
distributed::SerializeToIOBuf(out_varname, outvar,
*request_get_h_->dev_ctx(), response,
&cntl->response_attachment(), "", false);
}
}
void _GetVariableNoBarrier(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
google::protobuf::Closure* done) {
PADDLE_ENFORCE(request_getnobarrier_h_ != nullptr,
"RequestGetNoBarrier handler should be registed first!");
brpc::ClosureGuard done_guard(done);
brpc::Controller* cntl = static_cast<brpc::Controller*>(cntl_butil);
std::string varname = request->varname();
std::string out_varname = request->out_varname();
int trainer_id = request->trainer_id();
VLOG(3) << "RequestGetNoBarrier varname:" << varname
<< ", out_varname:" << out_varname << ", trainer_id:" << trainer_id
<< ", from:" << cntl->remote_side();
auto scope = request_getnobarrier_h_->scope();
paddle::framework::Variable* invar = nullptr;
paddle::framework::Variable* outvar = nullptr;
request_get_h_->Handle(varname, scope, invar, &outvar, trainer_id);
request_getnobarrier_h_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
distributed::SerializeToIOBuf(varname, outvar, *request_get_h_->dev_ctx(),
response, &cntl->response_attachment(), "",
false);
distributed::SerializeToIOBuf(
out_varname, outvar, *request_getnobarrier_h_->dev_ctx(), response,
&cntl->response_attachment(), "", false);
}
}
void PrefetchVariable(google::protobuf::RpcController* cntl_butil,
const VariableMessage* request,
VariableMessage* response,
......@@ -282,6 +333,7 @@ class BRPCServiceImpl : public SendRecvService {
private:
distributed::RequestHandler* request_send_h_{nullptr};
distributed::RequestHandler* request_get_h_{nullptr};
distributed::RequestHandler* request_getnobarrier_h_{nullptr};
distributed::RequestHandler* request_prefetch_h_{nullptr};
distributed::RequestHandler* request_checkpoint_h_{nullptr};
distributed::RequestHandler* request_get_monomer_handler_h_{nullptr};
......@@ -289,9 +341,10 @@ class BRPCServiceImpl : public SendRecvService {
distributed::RPCServer* rpc_server_{nullptr};
// FIXME(gongwb): brpc should support process one rpce use one threadpool.
// FIXME(gongwb): brpc should support process one rpc use one threadpool.
std::unique_ptr<paddle::framework::ThreadPool> send_threads_;
std::unique_ptr<paddle::framework::ThreadPool> get_threads_;
std::unique_ptr<paddle::framework::ThreadPool> getnobarrier_threads_;
std::unique_ptr<paddle::framework::ThreadPool> prefetch_threads_;
std::unique_ptr<paddle::framework::ThreadPool> checkpoint_notify_threads_;
};
......
......@@ -328,7 +328,8 @@ function run_brpc_test() {
========================================
EOF
set +x
declare -a other_tests=("test_listen_and_serv_op" "system_allocator_test")
declare -a other_tests=("test_listen_and_serv_op" "system_allocator_test" \
"rpc_server_test" "varhandle_test" "collective_server_test" "brpc_serde_test")
all_tests=`ctest -N`
for t in "${other_tests[@]}"
......
......@@ -16,6 +16,7 @@ import sys
import time
import socket
from contextlib import closing
from six import string_types
def wait_server_ready(endpoints):
......@@ -32,6 +33,7 @@ def wait_server_ready(endpoints):
wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"])
"""
assert not isinstance(endpoints, string_types)
while True:
all_ok = True
not_ready_endpoints = []
......@@ -45,7 +47,7 @@ def wait_server_ready(endpoints):
all_ok = False
not_ready_endpoints.append(ep)
if not all_ok:
sys.stderr.write("pserver not ready, wait 3 sec to retry...\n")
sys.stderr.write("server not ready, wait 3 sec to retry...\n")
sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) +
"\n")
sys.stderr.flush()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册