From be132719c5cc7a87df1923676be9cf205366f1cd Mon Sep 17 00:00:00 2001 From: pangyoki Date: Wed, 27 Jul 2022 16:36:51 +0800 Subject: [PATCH] fix RemoveIntermediateOut in fuse_elewise_add_act_pass while converting graph to program (#44593) * fix RemoveNode in fuse_elewise_add_act_pass * fix * change pointer to share_ptr * fix * fix * fix format * fix * fix graph_safe_remove_nodes --- .../framework/ir/fuse_elewise_add_act_pass.cc | 13 ++++++++++- paddle/fluid/framework/ir/graph.h | 2 ++ paddle/fluid/framework/ir/graph_helper.cc | 22 ++++++++++++++----- .../framework/ir/graph_pattern_detector.cc | 14 +++++++++--- .../framework/ir/graph_pattern_detector.h | 6 +++-- paddle/fluid/pybind/ir.cc | 5 ++++- 6 files changed, 50 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index 5bd26e9eb9..67aa5a822e 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -297,7 +297,18 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { } } } - GraphSafeRemoveNodes(graph, need_removed_nodes); + details::RemovedVars *saved_removed_nodes = new details::RemovedVars; + GraphSafeRemoveNodes(graph, need_removed_nodes, saved_removed_nodes); + if (!saved_removed_nodes->empty()) { + // TODO(pangyoki): If kRemovedVars exists, merge saved_removed_nodes into + // RemovedVars. + PADDLE_ENFORCE_EQ(graph->Has(details::kRemovedVars), + false, + platform::errors::PreconditionNotMet( + "Removed nodes are only saved for " + "fuse_elewise_add_act_pass in temporary.")); + graph->Set(details::kRemovedVars, saved_removed_nodes); + } } void FuseElewiseAddActPass::ReLinkNodes(Graph *graph, diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 5a95411077..3eb2df7011 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -45,6 +45,8 @@ namespace details { // This attr is not recommended, because the graph should not dependence // the program once it is built. constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs"; +constexpr char kRemovedVars[] = "removed_vars"; +typedef std::unordered_set> RemovedVars; } // namespace details namespace ir { diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 80568b7766..a7bf131805 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -549,6 +549,18 @@ static void GetGraphOpDesc(const std::vector &nodes, } } +template +static void GetGraphVarDesc(const Graph &graph, + const std::unordered_set &nodes, + std::vector *vars) { + for (T node : nodes) { + if (node->IsVar() && node->Var() && + node->GetVarNodeBlockId() == graph.GetBlockId()) { + vars->emplace_back(*node->Var()->Proto()); + } + } +} + static void GraphToBlock(const Graph &graph, proto::BlockDesc *block, const SortKind *sort_kind) { @@ -562,11 +574,11 @@ static void GraphToBlock(const Graph &graph, } std::vector vars_in_graph; - for (Node *node : graph.Nodes()) { - if (node->IsVar() && node->Var() && - node->GetVarNodeBlockId() == graph.GetBlockId()) { - vars_in_graph.emplace_back(*node->Var()->Proto()); - } + GetGraphVarDesc(graph, graph.Nodes(), &vars_in_graph); + if (graph.Has(details::kRemovedVars)) { + auto &removed_vars = graph.Get(details::kRemovedVars); + GetGraphVarDesc>( + graph, removed_vars, &vars_in_graph); } // add vars_in_graph to blcok diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6191c2efe9..cce1ec89a2 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -771,10 +771,18 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { return var->Name() == op->Op()->Output(argument)[nth]; } -void GraphSafeRemoveNodes(Graph *graph, - const std::unordered_set &nodes) { +void GraphSafeRemoveNodes( + Graph *graph, + const std::unordered_set &nodes, + std::unordered_set> *saved_nodes) { for (auto *node : nodes) { - graph->RemoveNode(const_cast(node)); + if (saved_nodes != nullptr) { + // prevent unique_ptr node from being released + saved_nodes->insert( + std::move(graph->RemoveNode(const_cast(node)))); + } else { + graph->RemoveNode(const_cast(node)); + } } for (auto *node : graph->Nodes()) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 00e565b716..794c25e85a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -392,8 +392,10 @@ bool HasOutput(Node* op, const std::string& argument); bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); // Graph safely remove some nodes, will automatically clean up the edges. -void GraphSafeRemoveNodes(Graph* graph, - const std::unordered_set& nodes); +void GraphSafeRemoveNodes( + Graph* graph, + const std::unordered_set& nodes, + std::unordered_set>* saved_nodes = nullptr); // Some pre-defined patterns those can be reused in multiple passes. // The related Fluid Layer or Op should be one pattern here for better re-usage diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index b8b127201c..73f7e9a098 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -50,7 +50,10 @@ using pybind11::return_value_policy; namespace paddle { namespace pybind { void BindGraph(py::module *m) { - m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes); + m->def("graph_safe_remove_nodes", + [](Graph *graph, const std::unordered_set &nodes) { + return GraphSafeRemoveNodes(graph, nodes); + }); m->def("has_circle", HasCircle); m->def("graph_num", GraphNum); m->def( -- GitLab