提交 268e9dc1 编写于 作者: Y Yancey1989

polish code

上级 ceefbf32
...@@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, ...@@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars( std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::vector<std::string> send_vars; std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send_vars" || op->Type() == "send") { // TODO(Yancey1989): use a graceful method to find send op,
// instead of the the hard code string
if (op->Type() == "send_vars") {
auto op_vars = op->InputArgumentNames(); auto op_vars = op->InputArgumentNames();
send_vars.reserve(send_vars.size() + send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end())); std::distance(op_vars.begin(), op_vars.end()));
...@@ -99,7 +103,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars( ...@@ -99,7 +103,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const { const ProgramDesc &program) const {
std::vector<std::string> recv_vars; std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "recv" || op->Type() == "send") { // TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard code string
if (op->Type() == "recv") {
auto op_vars = op->OutputArgumentNames(); auto op_vars = op->OutputArgumentNames();
recv_vars.reserve(recv_vars.size() + recv_vars.reserve(recv_vars.size() +
std::distance(op_vars.begin(), op_vars.end())); std::distance(op_vars.begin(), op_vars.end()));
...@@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
auto checker = [](const std::vector<std::string> &opvars, auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &rpc_vars) -> bool { const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) { for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos && if (var.find(".block") != std::string::npos &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) { std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true; return true;
...@@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp( ...@@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
return false; return false;
}; };
if (op.Type() == "split" || op.Type() == "split_byref" || return checker(op.OutputArgumentNames(), send_vars) ||
op.Type() == "split_selected_rows") { checker(op.InputArgumentNames(), recv_vars);
return checker(op.OutputArgumentNames(), send_vars);
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), recv_vars);
}
return false; return false;
} }
......
...@@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val); const auto ch = GetChannel(ep_val);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch, framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
this] { this] {
...@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, ...@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const std::string ep_val = ep; const std::string ep_val = ep;
const std::string var_name_val = var_name; const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val); const auto ch = GetChannel(ep_val);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch, framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
this] { this] {
...@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const std::string in_var_name_val = in_var_name; const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name; const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope; const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val); const auto ch = GetChannel(ep_val);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx, framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] { time_out, ch, this] {
...@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep, ...@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
} }
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep, ep); const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out); s->Prepare(time_out);
...@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { ...@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
} }
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) { void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep, ep); const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch); FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out); s->Prepare(time_out);
...@@ -248,10 +248,9 @@ bool RPCClient::Proceed() { ...@@ -248,10 +248,9 @@ bool RPCClient::Proceed() {
delete c; delete c;
return true; return true;
} }
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep, std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
const std::string& key) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(key); auto it = channels_.find(ep);
if (it != channels_.end()) { if (it != channels_.end()) {
return it->second; return it->second;
} }
...@@ -263,7 +262,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep, ...@@ -263,7 +262,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
auto ch = auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[key] = ch; channels_[ep] = ch;
return ch; return ch;
} }
......
...@@ -191,8 +191,7 @@ class RPCClient { ...@@ -191,8 +191,7 @@ class RPCClient {
private: private:
bool Proceed(); bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep, std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
const std::string& key);
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册