diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 740acfafb7594d8d9f3ca5439323ce76c5ed271a..2cfad606d25ac389bce3eed3a16229e2730737d1 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -61,6 +61,49 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { var->inputs.push_back(node); } } + + std::vector send_ops; + ir::Node *send_bar = nullptr; + std::vector recv_ops; + ir::Node *fetch_bar = nullptr; + for (ir::Node *node : Nodes()) { + if (node->Name() == "send") { + send_ops.push_back(node); + } else if (node->Name() == "send_barrier") { + PADDLE_ENFORCE(!send_bar, "only has one send barrier"); + send_bar = node; + } else if (node->Name() == "recv") { + recv_ops.push_back(node); + } else if (node->Name() == "fetch_barrier") { + PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier"); + fetch_bar = node; + } + } + if (send_bar) { + for (ir::Node *send : send_ops) { + ir::Node *dep_var = CreateControlDepVar(); + send->outputs.push_back(dep_var); + dep_var->inputs.push_back(send); + send_bar->inputs.push_back(dep_var); + dep_var->outputs.push_back(send_bar); + } + for (ir::Node *recv : recv_ops) { + ir::Node *dep_var = CreateControlDepVar(); + recv->inputs.push_back(dep_var); + dep_var->outputs.push_back(recv); + send_bar->outputs.push_back(dep_var); + dep_var->inputs.push_back(send_bar); + } + } + if (fetch_bar) { + for (ir::Node *recv : recv_ops) { + ir::Node *dep_var = CreateControlDepVar(); + recv->outputs.push_back(dep_var); + dep_var->inputs.push_back(recv); + fetch_bar->inputs.push_back(dep_var); + dep_var->outputs.push_back(fetch_bar); + } + } /** * We only handle write after read(WAR), since it should not have a write * after write in program. If there are write after write operators, we need