From 5fff8d7a55d37608a64f3ac6fc9b002bb32ed985 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 27 Jul 2018 17:04:02 +0800 Subject: [PATCH] add distributed training deps. --- paddle/fluid/framework/ir/graph.cc | 43 ++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 740acfafb..2cfad606d 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -61,6 +61,49 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { var->inputs.push_back(node); } } + + std::vector send_ops; + ir::Node *send_bar = nullptr; + std::vector recv_ops; + ir::Node *fetch_bar = nullptr; + for (ir::Node *node : Nodes()) { + if (node->Name() == "send") { + send_ops.push_back(node); + } else if (node->Name() == "send_barrier") { + PADDLE_ENFORCE(!send_bar, "only has one send barrier"); + send_bar = node; + } else if (node->Name() == "recv") { + recv_ops.push_back(node); + } else if (node->Name() == "fetch_barrier") { + PADDLE_ENFORCE(!fetch_bar, "only has one fetch barrier"); + fetch_bar = node; + } + } + if (send_bar) { + for (ir::Node *send : send_ops) { + ir::Node *dep_var = CreateControlDepVar(); + send->outputs.push_back(dep_var); + dep_var->inputs.push_back(send); + send_bar->inputs.push_back(dep_var); + dep_var->outputs.push_back(send_bar); + } + for (ir::Node *recv : recv_ops) { + ir::Node *dep_var = CreateControlDepVar(); + recv->inputs.push_back(dep_var); + dep_var->outputs.push_back(recv); + send_bar->outputs.push_back(dep_var); + dep_var->inputs.push_back(send_bar); + } + } + if (fetch_bar) { + for (ir::Node *recv : recv_ops) { + ir::Node *dep_var = CreateControlDepVar(); + recv->outputs.push_back(dep_var); + dep_var->inputs.push_back(recv); + fetch_bar->inputs.push_back(dep_var); + dep_var->outputs.push_back(fetch_bar); + } + } /** * We only handle write after read(WAR), since it should not have a write * after write in program. If there are write after write operators, we need -- GitLab