提交 25706d08 编写于 作者: X Xin Pan

properly set up dep of concat and fetch_bar

上级 398cfb47
......@@ -24,6 +24,68 @@ namespace paddle {
namespace framework {
namespace ir {
std::vector<std::string> FindDistTrainSendVars(
const std::vector<ir::Node *> &nodes) {
std::vector<std::string> send_vars;
// since parameters are all in block 0,
// it's enough to only scan send ops in block 0
for (auto &node : nodes) {
auto op_vars = node->Op()->InputArgumentNames();
send_vars.reserve(send_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
}
return send_vars;
}
std::vector<std::string> FindDistTrainRecvVars(
const std::vector<ir::Node *> &nodes) {
std::vector<std::string> recv_vars;
for (auto &node : nodes) {
auto op_vars = node->Op()->OutputArgumentNames();
recv_vars.reserve(recv_vars.size() +
std::distance(op_vars.begin(), op_vars.end()));
recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
}
return recv_vars;
}
bool IsDistTrainOp(ir::Node *node, const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) {
if (send_vars.size() == 0 || recv_vars.size() == 0) {
return false;
}
/**
* Check any of opvars contains `.block` and in sendvars
*/
auto checker = [](const std::vector<std::string> &opvars,
const std::vector<std::string> &rpc_vars) -> bool {
for (auto &var : opvars) {
// a variable name with the suffix `.block` means it's a splited
// variable by (DistributeTranspiler)
// [python/paddle/fluid/transpiler/distribute_transpiler.py]
if (var.find(".block") != std::string::npos &&
std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
return true;
}
}
return false;
};
std::vector<std::string> input_var_names;
std::vector<std::string> output_var_names;
for (ir::Node *input : node->inputs) {
input_var_names.push_back(input->Name());
}
for (ir::Node *output : node->outputs) {
output_var_names.push_back(output->Name());
}
return checker(output_var_names, send_vars) ||
checker(input_var_names, recv_vars);
}
Graph::Graph(const ProgramDesc &program) : program_(program) {
VLOG(3) << "block in program:" << program_.Size();
std::unordered_map<std::string, VarDesc *> all_vars;
......@@ -104,6 +166,21 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
dep_var->outputs.push_back(fetch_bar);
}
}
std::vector<std::string> send_vars = FindDistTrainSendVars(send_ops);
std::vector<std::string> recv_vars = FindDistTrainRecvVars(recv_ops);
for (ir::Node *node : Nodes()) {
if (IsDistTrainOp(node, send_vars, recv_vars)) {
if (fetch_bar && node->Name() == "concat") {
ir::Node *dep_var = CreateControlDepVar();
fetch_bar->outputs.push_back(dep_var);
dep_var->inputs.push_back(fetch_bar);
node->inputs.push_back(dep_var);
dep_var->outputs.push_back(node);
}
}
}
/**
* We only handle write after read(WAR), since it should not have a write
* after write in program. If there are write after write operators, we need
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册