提交 268e9dc1 编写于 作者: Y Yancey1989

polish code

上级 ceefbf32
......@@ -84,8 +84,12 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
const ProgramDesc &program) const {
std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send_vars" || op->Type() == "send") {
// TODO(Yancey1989): use a graceful method to find send op,
// instead of the hard-coded string
if (op->Type() == "send_vars") {
auto op_vars = op->InputArgumentNames();
send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
......@@ -99,7 +103,9 @@ std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
const ProgramDesc &program) const {
std::vector<std::string> recv_vars;
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "recv" || op->Type() == "send") {
// TODO(Yancey1989): use a graceful method to find recv op,
// instead of the hard-coded string
if (op->Type() == "recv") {
auto op_vars = op->OutputArgumentNames();
recv_vars.reserve(recv_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
......@@ -122,6 +128,9 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) {
// a variable name containing `.block` means it is a variable that was
// split by the DistributeTranspiler
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
......@@ -130,13 +139,8 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
return false;
};
if (op.Type() == "split" || op.Type() == "split_byref" ||
op.Type() == "split_selected_rows") {
return checker(op.OutputArgumentNames(), send_vars);
} else if (op.Type() == "concat") {
return checker(op.InputArgumentNames(), recv_vars);
}
return checker(op.OutputArgumentNames(), send_vars) ||
checker(op.InputArgumentNames(), recv_vars);
return false;
}
......
......@@ -34,7 +34,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
const auto ch = GetChannel(ep_val);
framework::AsyncIO([var_name_val, p_ctx, ep_val, p_scope, time_out, ch,
this] {
......@@ -88,7 +88,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val, ep_val + ":" + var_name_val);
const auto ch = GetChannel(ep_val);
framework::AsyncIO([var_name_val, ep_val, p_scope, p_ctx, time_out, ch,
this] {
......@@ -132,7 +132,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
const std::string in_var_name_val = in_var_name;
const std::string out_var_name_val = out_var_name;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val, ep_val + ":" + in_var_name_val);
const auto ch = GetChannel(ep_val);
framework::AsyncIO([in_var_name_val, out_var_name_val, ep_val, p_scope, p_ctx,
time_out, ch, this] {
......@@ -165,7 +165,7 @@ bool RPCClient::AsyncPrefetchVariable(const std::string& ep,
}
void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep, ep);
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);
......@@ -178,7 +178,7 @@ void RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
}
void RPCClient::AsyncSendFetchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep, ep);
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
s->Prepare(time_out);
......@@ -248,10 +248,9 @@ bool RPCClient::Proceed() {
delete c;
return true;
}
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
const std::string& key) {
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(key);
auto it = channels_.find(ep);
if (it != channels_.end()) {
return it->second;
}
......@@ -263,7 +262,7 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep,
auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[key] = ch;
channels_[ep] = ch;
return ch;
}
......
......@@ -191,8 +191,7 @@ class RPCClient {
private:
bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep,
const std::string& key);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
private:
grpc::CompletionQueue cq_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册