diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
index f1f8674caf663ce38df5a2eecbcf690b5ca87dc4..dc9183e96aa6ac898e24e162177a1865a097ab1b 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc
@@ -221,15 +221,15 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
   // forward, backward nodes. E.g. you can't append an forward node
   // at the end of the node list.
   // TODO(panyx0718): FIXME: Needs to sort by forward->backward order.
-  for (auto &node : nodes) {
-    if (node->NodeType() != ir::Node::Type::kOperation) continue;
+  for (ir::Node *node : TopologySortOperationFromInToOut(nodes)) {
+    VLOG(3) << "apply node: " << node->Name() << reinterpret_cast<void *>(node);
     if (boost::get<int>(
             node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
         static_cast<int>(OpRole::kRPC)) {
-      CreateRPCOp(&result, node.get());
-    } else if (IsDistTrainOp(node.get(), send_vars, recv_vars)) {
-      CreateDistTrainOp(&result, node.get());
-    } else if (IsScaleLossOp(node.get())) {
+      CreateRPCOp(&result, node);
+    } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
+      CreateDistTrainOp(&result, node);
+    } else if (IsScaleLossOp(node)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
           BuildStrategy::GradientScaleStrategy::kCustomized) {
@@ -240,10 +240,11 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
       // It also assumes backward op will always follow the forward op in
       // the block.
       is_forwarding = false;
+      LOG(ERROR) << "forward flipping!!!!!!!";
     } else {
-      int op_dev_id = GetOpDeviceID(node.get());
+      int op_dev_id = GetOpDeviceID(node);
       if (op_dev_id != -1) {  // This op only runs on one specific device.
-        CreateComputationalOp(&result, node.get(), op_dev_id);
+        CreateComputationalOp(&result, node, op_dev_id);
         for (ir::Node *n : node->outputs) {
           var_name_on_devices_.emplace(n->Name(), op_dev_id);
         }
@@ -252,13 +253,11 @@ std::unique_ptr<Graph> MultiDevSSAGraphBuilder::Apply(
         // gradients.
         if (node->Op()->Type() == "read" && strategy_.enable_data_balance_) {
           node->Op()->SetAttr("throw_eof_exp", false);
-          CreateComputationalOps(&result, node.get(), places_.size());
-          // TODO(paddle-dev): builder shouldn't depend on the out logic of
-          // a specific op.
+          CreateComputationalOps(&result, node, places_.size());
           const auto &data_var_names = node->Op()->Output("Out");
           InsertDataBalanceOp(&result, data_var_names);
         } else {
-          CreateComputationalOps(&result, node.get(), places_.size());
+          CreateComputationalOps(&result, node, places_.size());
         }
 
         if (!is_forwarding && places_.size() > 1) {
@@ -479,8 +478,8 @@ int MultiDevSSAGraphBuilder::GetOpDeviceID(ir::Node *node) const {
   PADDLE_ENFORCE_EQ(param_grad.size(), 2U);
 
   int dev_id = GetVarDeviceID(param_grad[1]);
-  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s]",
-                    node->Op()->Type(), param_grad[0]);
+  PADDLE_ENFORCE_NE(dev_id, -1, "dev_id should not be -1.[%s, %s, %s]",
+                    node->Op()->Type(), param_grad[0], param_grad[1]);
   return dev_id;
 }
 
diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc
index 7bc130ef6e8d2e0caf6e445d12950b87e6dd4dbd..2be4bb009eff2866f39f08f11052822eb1fdea5a 100644
--- a/paddle/fluid/framework/details/ssa_graph_builder.cc
+++ b/paddle/fluid/framework/details/ssa_graph_builder.cc
@@ -37,6 +37,17 @@ void SSAGraphBuilder::PolishGraphToSupportDataHazards(Graph *graph) {
             continue;
           }
 
+          bool has_dep = false;
+          for (auto read_out : read_op->Outputs()) {
+            for (auto write_in : write_op->Inputs()) {
+              if (read_out == write_in) {
+                has_dep = true;
+                break;
+              }
+            }
+          }
+          if (has_dep) continue;
+
           auto *dep_var = new DummyVarHandle(
               graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable));
           read_op->AddOutput(dep_var);
diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc
index e4021aa92b6da2343b604fb7bc01d31edb97d842..f297461ab27df653a529b2d08320a8bf95daac9c 100644
--- a/paddle/fluid/framework/ir/graph.cc
+++ b/paddle/fluid/framework/ir/graph.cc
@@ -12,14 +12,35 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
+#include <unordered_set>
+
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
 
 namespace paddle {
 namespace framework {
+namespace {
+void SortHelper(
+    const std::map<ir::Node *, std::unordered_set<ir::Node *>> &adj_list,
+    ir::Node *node, std::unordered_set<ir::Node *> *visited,
+    std::vector<ir::Node *> *ret) {
+  visited->insert(node);
+
+  for (auto adj : adj_list.at(node)) {
+    if (visited->find(adj) == visited->end()) {
+      SortHelper(adj_list, adj, visited, ret);
+    }
+  }
+
+  VLOG(3) << "topology sort insert: " << node->Name()
+          << reinterpret_cast<void *>(node) << " input " << node->inputs.size();
+  ret->push_back(node);
+}
+}  // namespace
 
-// NOTE(paddle-dev): This graph contains circle.
 Graph::Graph(const ProgramDesc &program) : program_(program) {
   VLOG(3) << "block in program:" << program_.Size();
   std::unordered_map<std::string, VarDesc *> all_vars;
@@ -27,40 +48,128 @@ Graph::Graph(const ProgramDesc &program) : program_(program) {
     all_vars.emplace(var->Name(), var);
   }
 
-  std::map<std::string, ir::Node *> var_nodes;
+  ir::Node *last_backward = nullptr;
+  std::vector<ir::Node *> optimize_ops;
+  std::map<std::string, std::vector<ir::Node *>> var_nodes;
   for (auto *op : program.Block(0).AllOps()) {
     ir::Node *node = CreateOpNode(op);
+    if (boost::get<int>(
+            op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+        static_cast<int>(OpRole::kBackward)) {
+      last_backward = node;
+    } else if (boost::get<int>(
+                   op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
+               static_cast<int>(OpRole::kOptimize)) {
+      optimize_ops.push_back(node);
+    }
 
     for (auto &each_var_name : op->InputArgumentNames()) {
       ir::Node *var = nullptr;
       if (var_nodes.find(each_var_name) != var_nodes.end()) {
-        var = var_nodes.at(each_var_name);
+        var = var_nodes.at(each_var_name).back();
       } else if (all_vars.count(each_var_name) != 0) {
         var = CreateVarNode(all_vars.at(each_var_name));
-        var_nodes[each_var_name] = var;
+        var_nodes[each_var_name].push_back(var);
       } else {
         // TODO(paddle-dev): Seems some assumption doesn't hold?
         VLOG(3) << op->Type()
                 << " input var not in all_var list: " << each_var_name;
         var = CreateEmptyNode(each_var_name, ir::Node::Type::kVariable);
-        var_nodes[each_var_name] = var;
+        var_nodes[each_var_name].push_back(var);
       }
       node->inputs.push_back(var);
       var->outputs.push_back(node);
     }
 
     for (auto &each_var_name : op->OutputArgumentNames()) {
-      ir::Node *var = nullptr;
-      if (var_nodes.find(each_var_name) != var_nodes.end()) {
-        var = var_nodes.at(each_var_name);
-      } else {
-        var = CreateVarNode(all_vars.at(each_var_name));
-        var_nodes[each_var_name] = var;
-      }
+      ir::Node *var = CreateVarNode(all_vars.at(each_var_name));
+      var_nodes[each_var_name].push_back(var);
       node->outputs.push_back(var);
       var->inputs.push_back(node);
     }
   }
+
+  for (auto &var : var_nodes) {
+    auto &versions = var.second;
+    if (versions.size() <= 1) continue;
+
+    auto it_new = versions.rbegin();
+    auto it_old = versions.rbegin();
+    ++it_old;
+    for (; it_old != versions.rend(); it_new = it_old, ++it_old) {
+      ir::Node *write_op =
+          (*it_new)->inputs.empty() ? nullptr : (*it_new)->inputs[0];
+      const auto &read_ops = (*it_old)->outputs;
+
+      for (auto *read_op : read_ops) {
+        // Manually add a dependency var from read_op to write_op;
+        if (read_op == write_op) {
+          // Read Write is the same op.
+          continue;
+        }
+        ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
+        read_op->outputs.push_back(dep_var);
+        dep_var->inputs.push_back(read_op);
+        write_op->inputs.push_back(dep_var);
+        dep_var->outputs.push_back(write_op);
+      }
+    }
+  }
+
+  if (last_backward) {
+    for (ir::Node *opt_node : optimize_ops) {
+      ir::Node *dep_var = CreateEmptyNode("dummy", ir::Node::Type::kVariable);
+      last_backward->outputs.push_back(dep_var);
+      dep_var->inputs.push_back(last_backward);
+      opt_node->inputs.push_back(dep_var);
+      dep_var->outputs.push_back(opt_node);
+      VLOG(3) << "appending connect: " << last_backward->Name()
+              << reinterpret_cast<void *>(last_backward) << "->"
+              << opt_node->Name() << reinterpret_cast<void *>(opt_node);
+    }
+  }
+}
+
+std::vector<ir::Node *> TopologySortOperationFromInToOut(
+    const std::vector<std::unique_ptr<ir::Node>> &nodes) {
+  std::map<ir::Node *, std::unordered_set<ir::Node *>> adj_list;
+  std::unordered_set<ir::Node *> visited;
+  std::vector<ir::Node *> ret;
+
+  for (auto &n : nodes) {
+    if (n->NodeType() != ir::Node::Type::kOperation) continue;
+    if (adj_list.find(n.get()) == adj_list.end()) {
+      adj_list[n.get()] = std::unordered_set<ir::Node *>();
+    }
+    for (auto &var : n->inputs) {
+      for (auto &adj_n : var->inputs) {
+        PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation);
+        adj_list[n.get()].insert(adj_n);
+        LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast<void *>(adj_n)
+                   << " -> " << n->Name() << reinterpret_cast<void *>(n.get())
+                   << " via " << var->Name() << reinterpret_cast<void *>(var);
+      }
+    }
+  }
+
+  for (auto adj : adj_list) {
+    if (visited.find(adj.first) == visited.end()) {
+      SortHelper(adj_list, adj.first, &visited, &ret);
+    }
+  }
+
+  for (ir::Node *n : ret) {
+    std::unordered_set<ir::Node *> dummy;
+    n->inputs.erase(
+        std::remove_if(n->inputs.begin(), n->inputs.end(),
+                       [n](ir::Node *t) { return t->Name() == "dummy"; }),
+        n->inputs.end());
+    n->outputs.erase(
+        std::remove_if(n->outputs.begin(), n->outputs.end(),
+                       [n](ir::Node *t) { return t->Name() == "dummy"; }),
+        n->outputs.end());
+  }
+  return ret;
 }
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h
index b4ac135b029005b723abca2cb9b9a9aa175eda40..0242edecf4525fad45d9203740997035587e7130 100644
--- a/paddle/fluid/framework/ir/graph.h
+++ b/paddle/fluid/framework/ir/graph.h
@@ -78,5 +78,8 @@ class Graph {
   std::map<std::string, std::function<void(void)>> attr_dels_;
 };
 
+std::vector<ir::Node *> TopologySortOperationFromInToOut(
+    const std::vector<std::unique_ptr<ir::Node>>& nodes);
+
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_test.cc b/paddle/fluid/framework/ir/graph_test.cc
index 4e23bf124f8822e25be0f6b1c7c8c5de4e4f600a..186047b370c778a43a2828d249716ea0bafee39e 100644
--- a/paddle/fluid/framework/ir/graph_test.cc
+++ b/paddle/fluid/framework/ir/graph_test.cc
@@ -76,6 +76,7 @@ TEST(GraphTest, Basic) {
   op->SetType("sum");
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});
+  op->SetAttr("op_role", 1);
 
   prog.MutableBlock(0)->Var("test_a")->SetType(proto::VarType::SELECTED_ROWS);
   prog.MutableBlock(0)->Var("test_b")->SetType(proto::VarType::SELECTED_ROWS);
diff --git a/paddle/fluid/framework/ir/node.h b/paddle/fluid/framework/ir/node.h
index b98c29b81ddc2f57553b8fe76fcfeb0936ddd837..97b64a6017ef08ffc73ae22beb18321934506078 100644
--- a/paddle/fluid/framework/ir/node.h
+++ b/paddle/fluid/framework/ir/node.h
@@ -50,6 +50,7 @@ class Node {
     PADDLE_ENFORCE(type_ == Type::kVariable);
     return var_desc_;
   }
+
   OpDesc* Op() {
     PADDLE_ENFORCE(type_ == Type::kOperation);
     return op_desc_;
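
The sketch below, which is not part of the patch, replays the technique the new TopologySortOperationFromInToOut/SortHelper code uses: a post-order DFS over an adjacency list keyed on each operation's input-side operations, so every node is emitted after everything it depends on. The SimpleNode type, the TopologySort wrapper, and the three-node graph in main() are illustrative assumptions only, not Paddle's ir::Node API.

#include <iostream>
#include <map>
#include <string>
#include <unordered_set>
#include <vector>

struct SimpleNode {
  std::string name;
  std::vector<SimpleNode *> inputs;  // nodes this node depends on
};

namespace {
// Post-order DFS: visit all dependencies first, then emit the node,
// so each node appears after every node it depends on.
void SortHelper(
    const std::map<SimpleNode *, std::unordered_set<SimpleNode *>> &adj_list,
    SimpleNode *node, std::unordered_set<SimpleNode *> *visited,
    std::vector<SimpleNode *> *ret) {
  visited->insert(node);
  for (SimpleNode *adj : adj_list.at(node)) {
    if (visited->count(adj) == 0) {
      SortHelper(adj_list, adj, visited, ret);
    }
  }
  ret->push_back(node);
}
}  // namespace

// Assumes every dependency of a node is itself an element of `nodes`
// and that the graph is a DAG (no cycle detection, as in the patch).
std::vector<SimpleNode *> TopologySort(const std::vector<SimpleNode *> &nodes) {
  std::map<SimpleNode *, std::unordered_set<SimpleNode *>> adj_list;
  for (SimpleNode *n : nodes) {
    adj_list[n];  // ensure an entry even for nodes with no inputs
    for (SimpleNode *in : n->inputs) {
      adj_list[n].insert(in);
    }
  }
  std::unordered_set<SimpleNode *> visited;
  std::vector<SimpleNode *> ret;
  for (auto &kv : adj_list) {
    if (visited.count(kv.first) == 0) {
      SortHelper(adj_list, kv.first, &visited, &ret);
    }
  }
  return ret;  // inputs-first order, e.g. forward -> backward -> optimize
}

int main() {
  SimpleNode forward{"forward", {}};
  SimpleNode backward{"backward", {&forward}};
  SimpleNode optimize{"optimize", {&backward}};
  // Regardless of the order the nodes are handed in, the chain of
  // dependencies forces the output order forward, backward, optimize.
  for (SimpleNode *n : TopologySort({&optimize, &backward, &forward})) {
    std::cout << n->name << "\n";
  }
  return 0;
}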