From f6d99d1f73f1b57ee94a0db7fb6c039ff72085de Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Fri, 20 Jul 2018 15:15:21 +0800 Subject: [PATCH] polish --- .../details/multi_devices_graph_builder.cc | 21 --- .../framework/details/ssa_graph_builder.cc | 42 ----- .../framework/details/ssa_graph_builder.h | 9 - paddle/fluid/framework/ir/graph.cc | 155 +----------------- 4 files changed, 7 insertions(+), 220 deletions(-) diff --git a/paddle/fluid/framework/details/multi_devices_graph_builder.cc b/paddle/fluid/framework/details/multi_devices_graph_builder.cc index 4050424e7..f5e99c574 100644 --- a/paddle/fluid/framework/details/multi_devices_graph_builder.cc +++ b/paddle/fluid/framework/details/multi_devices_graph_builder.cc @@ -216,21 +216,6 @@ std::vector SortOpsAndDelayOptimizeOp(const ir::Graph &graph) { sorted_ret.insert(sorted_ret.begin() + last_backward, optimize_ops.begin(), optimize_ops.end()); - - for (ir::Node *n : sorted_ret) { - n->inputs.erase(std::remove_if(n->inputs.begin(), n->inputs.end(), - [n](ir::Node *t) { - return t->Name() == - ir::Node::kControlDepVarName; - }), - n->inputs.end()); - n->outputs.erase(std::remove_if(n->outputs.begin(), n->outputs.end(), - [n](ir::Node *t) { - return t->Name() == - ir::Node::kControlDepVarName; - }), - n->outputs.end()); - } return sorted_ret; } @@ -365,12 +350,6 @@ std::unique_ptr MultiDevSSAGraphBuilder::Apply( } } - /* - Dependency graph has been constructed. However, there are still data - hazards need to be handled. - */ - PolishGraphToSupportDataHazards(&result); - /* * Only variables should be the leaves of graph. */ diff --git a/paddle/fluid/framework/details/ssa_graph_builder.cc b/paddle/fluid/framework/details/ssa_graph_builder.cc index 3c579f427..dcdcb28ac 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.cc +++ b/paddle/fluid/framework/details/ssa_graph_builder.cc @@ -17,48 +17,6 @@ namespace paddle { namespace framework { namespace details { -void SSAGraphBuilder::PolishGraphToSupportDataHazards(ir::Graph *graph) { - for (auto &var_map : graph->Get("vars")) { - for (auto &name_pair : var_map) { - if (name_pair.second.size() <= 1) { - continue; - } - auto it_new = name_pair.second.rbegin(); - auto it_old = name_pair.second.rbegin(); - ++it_old; - for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) { - OpHandleBase *write_op = (*it_new)->GeneratedOp(); - const auto &read_ops = (*it_old)->PendingOps(); - - for (auto *read_op : read_ops) { - // Manually add a dependency var from read_op to write_op; - if (read_op == write_op) { - // Read Write is the same op. - continue; - } - - bool has_dep = false; - for (auto read_out : read_op->Outputs()) { - for (auto write_in : write_op->Inputs()) { - if (read_out == write_in) { - has_dep = true; - break; - } - } - } - if (has_dep) continue; - - auto *dep_var = new DummyVarHandle( - graph->CreateEmptyNode("dummy", ir::Node::Type::kVariable)); - read_op->AddOutput(dep_var); - write_op->AddInput(dep_var); - graph->Get("dep_vars").emplace(dep_var); - } - } - } - } -} - VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle( ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset) { diff --git a/paddle/fluid/framework/details/ssa_graph_builder.h b/paddle/fluid/framework/details/ssa_graph_builder.h index f64445b47..e99bab518 100644 --- a/paddle/fluid/framework/details/ssa_graph_builder.h +++ b/paddle/fluid/framework/details/ssa_graph_builder.h @@ -57,15 +57,6 @@ class SSAGraphBuilder : public ir::Pass { DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder); protected: - /** - * We only handle write after read(WAR), since it should not have a write - * after write in program. If there are write after write operators, we need - * prune them. - * - * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) - */ - static void PolishGraphToSupportDataHazards(ir::Graph *graph); - static VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph, ir::Node *node, const platform::Place &place, size_t place_offset); diff --git a/paddle/fluid/framework/ir/graph.cc b/paddle/fluid/framework/ir/graph.cc index 18211f2e2..769dddbc5 100644 --- a/paddle/fluid/framework/ir/graph.cc +++ b/paddle/fluid/framework/ir/graph.cc @@ -23,39 +23,6 @@ limitations under the License. */ namespace paddle { namespace framework { namespace ir { -/* -namespace { -void SortHelper( - const std::map> &adj_list, - ir::Node *node, std::unordered_set *visited, - std::vector *ret) { - visited->insert(node); - - for (auto adj : adj_list.at(node)) { - if (visited->find(adj) == visited->end()) { - SortHelper(adj_list, adj, visited, ret); - } - } - - VLOG(3) << "topology sort insert: " << node->Name() - << reinterpret_cast(node) << " input " << node->inputs.size(); - ret->push_back(node); -} - -std::vector TopologySortOperations( - const std::map> &adj_list) { - std::unordered_set visited; - std::vector ret; - - for (auto adj : adj_list) { - if (visited.find(adj.first) == visited.end()) { - SortHelper(adj_list, adj.first, &visited, &ret); - } - } - return ret; -} -} // namespace -*/ Graph::Graph(const ProgramDesc &program) : program_(program) { VLOG(3) << "block in program:" << program_.Size(); @@ -93,6 +60,13 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { var->inputs.push_back(node); } } + /** + * We only handle write after read(WAR), since it should not have a write + * after write in program. If there are write after write operators, we need + * prune them. + * + * https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR) + */ for (auto &var : var_nodes) { auto &versions = var.second; if (versions.size() <= 1) continue; @@ -121,121 +95,6 @@ Graph::Graph(const ProgramDesc &program) : program_(program) { } } } - -/* -bool HasCircleHelper(ir::Node* node, - const std::map> -&adj_list, - std::unordered_set* visited, - std::unordered_set* in_trace) { - if (visited->find(node) == visited->end()) { - visited->insert(node); - in_trace->insert(node); - - for (ir::Node *in : adj_list.at(node)) { - if (visited->find(in) == visited->end() && - HasCircleHelper(in, adj_list, visited, in_trace)) { - return true; - } else if (in_trace->find(in) != in_trace->end()) { - return true; - } - } - } - in_trace->erase(node); - return false; -} - -bool HasCircle(const std::map> -&adj_list) { - std::unordered_set visited; - std::unordered_set in_trace; - for (auto& adj : adj_list) { - if (HasCircleHelper(adj.first, adj_list, &visited, &in_trace)) { - return true; - } - } - return false; -} - -std::map> BuildOperationAdjList( - const std::vector &nodes) { - std::map> adj_list; - - for (auto &n : nodes) { - if (n->NodeType() != ir::Node::Type::kOperation) continue; - if (adj_list.find(n) == adj_list.end()) { - adj_list[n] = std::unordered_set(); - } - for (auto &var : n->inputs) { - for (auto &adj_n : var->inputs) { - PADDLE_ENFORCE(adj_n->NodeType() == ir::Node::Type::kOperation); - adj_list[n].insert(adj_n); - LOG(ERROR) << "adj " << adj_n->Name() << reinterpret_cast(adj_n) - << " -> " << n->Name() << reinterpret_cast(n) - << " via " << var->Name() << reinterpret_cast(var); - } - } - } - return adj_list; -} - -std::vector TopologySortOperationsOperationFromInToOut( - const std::vector> &nodes) { - std::vector tmp; - for (auto& n : nodes) { - tmp.push_back(n.get()); - } - std::map> adj_list = -BuildOperationAdjList(tmp); - - PADDLE_ENFORCE(!HasCircle(adj_list)); - std::vector ret = TopologySortOperations(adj_list); - - ir::Node *last_backward = nullptr; - std::vector optimize_ops; - for (ir::Node* n : ret) { - if (boost::get( - n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == - static_cast(OpRole::kBackward)) { - last_backward = n; - } else if (boost::get( - n->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) == - static_cast(OpRole::kOptimize)) { - optimize_ops.push_back(n); - } - } - - if (last_backward) { - for (ir::Node *opt_node : optimize_ops) { - ir::Node *dep_var = CreateEmptyNode(ir::Node::kControlDepVarName, - ir::Node::Type::kVariable); - last_backward->outputs.push_back(dep_var); - dep_var->inputs.push_back(last_backward); - opt_node->inputs.push_back(dep_var); - dep_var->outputs.push_back(opt_node); - VLOG(3) << "appending connect: " << last_backward->Name() - << reinterpret_cast(last_backward) << "->" - << opt_node->Name() << reinterpret_cast(opt_node); - } - } - - PADDLE_ENFORCE(!HasCircle(adj_list)); - for (ir::Node *n : ret) { - std::unordered_set dummy; - n->inputs.erase( - std::remove_if(n->inputs.begin(), n->inputs.end(), - [n](ir::Node *t) { - return t->Name() == ir::Node::kControlDepVarName; }), - n->inputs.end()); - n->outputs.erase( - std::remove_if(n->outputs.begin(), n->outputs.end(), - [n](ir::Node *t) { - return t->Name() == ir::Node::kControlDepVarName; }), - n->outputs.end()); - } - return ret; -}*/ - } // namespace ir } // namespace framework } // namespace paddle -- GitLab