diff --git a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc index 5bd26e9eb9f2d655434eea6a80263e138e8b956e..67aa5a822edae774d8a53c901826d87ede7a841d 100644 --- a/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc +++ b/paddle/fluid/framework/ir/fuse_elewise_add_act_pass.cc @@ -297,7 +297,18 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const { } } } - GraphSafeRemoveNodes(graph, need_removed_nodes); + details::RemovedVars *saved_removed_nodes = new details::RemovedVars; + GraphSafeRemoveNodes(graph, need_removed_nodes, saved_removed_nodes); + if (!saved_removed_nodes->empty()) { + // TODO(pangyoki): If kRemovedVars exists, merge saved_removed_nodes into + // RemovedVars. + PADDLE_ENFORCE_EQ(graph->Has(details::kRemovedVars), + false, + platform::errors::PreconditionNotMet( + "Removed nodes are only saved for " + "fuse_elewise_add_act_pass in temporary.")); + graph->Set(details::kRemovedVars, saved_removed_nodes); + } } void FuseElewiseAddActPass::ReLinkNodes(Graph *graph, diff --git a/paddle/fluid/framework/ir/graph.h b/paddle/fluid/framework/ir/graph.h index 5a954110775d67c1a6e4d3ba20097a9ee248cca8..3eb2df7011c7ed15b54d427beb4602f4c2cddf4f 100644 --- a/paddle/fluid/framework/ir/graph.h +++ b/paddle/fluid/framework/ir/graph.h @@ -45,6 +45,8 @@ namespace details { // This attr is not recommended, because the graph should not dependence // the program once it is built. constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs"; +constexpr char kRemovedVars[] = "removed_vars"; +typedef std::unordered_set> RemovedVars; } // namespace details namespace ir { diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc index 80568b7766503b9026de33bfb084b8679679b162..a7bf131805dc143ac967f8ac1cd97fbfe3fdc9a0 100644 --- a/paddle/fluid/framework/ir/graph_helper.cc +++ b/paddle/fluid/framework/ir/graph_helper.cc @@ -549,6 +549,18 @@ static void GetGraphOpDesc(const std::vector &nodes, } } +template +static void GetGraphVarDesc(const Graph &graph, + const std::unordered_set &nodes, + std::vector *vars) { + for (T node : nodes) { + if (node->IsVar() && node->Var() && + node->GetVarNodeBlockId() == graph.GetBlockId()) { + vars->emplace_back(*node->Var()->Proto()); + } + } +} + static void GraphToBlock(const Graph &graph, proto::BlockDesc *block, const SortKind *sort_kind) { @@ -562,11 +574,11 @@ static void GraphToBlock(const Graph &graph, } std::vector vars_in_graph; - for (Node *node : graph.Nodes()) { - if (node->IsVar() && node->Var() && - node->GetVarNodeBlockId() == graph.GetBlockId()) { - vars_in_graph.emplace_back(*node->Var()->Proto()); - } + GetGraphVarDesc(graph, graph.Nodes(), &vars_in_graph); + if (graph.Has(details::kRemovedVars)) { + auto &removed_vars = graph.Get(details::kRemovedVars); + GetGraphVarDesc>( + graph, removed_vars, &vars_in_graph); } // add vars_in_graph to blcok diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 6191c2efe9087c13a834e7dc602d8ba5e8da5c84..cce1ec89a2e82e3e4a11a7ebc85823f3663c972a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -771,10 +771,18 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) { return var->Name() == op->Op()->Output(argument)[nth]; } -void GraphSafeRemoveNodes(Graph *graph, - const std::unordered_set &nodes) { +void GraphSafeRemoveNodes( + Graph *graph, + const std::unordered_set &nodes, + std::unordered_set> *saved_nodes) { for (auto *node : nodes) { - graph->RemoveNode(const_cast(node)); + if (saved_nodes != nullptr) { + // prevent unique_ptr node from being released + saved_nodes->insert( + std::move(graph->RemoveNode(const_cast(node)))); + } else { + graph->RemoveNode(const_cast(node)); + } } for (auto *node : graph->Nodes()) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 00e565b7161a2a616b47257bea40879abf9651d8..794c25e85a555fd93f750222afeba8d3896e289a 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -392,8 +392,10 @@ bool HasOutput(Node* op, const std::string& argument); bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth); // Graph safely remove some nodes, will automatically clean up the edges. -void GraphSafeRemoveNodes(Graph* graph, - const std::unordered_set& nodes); +void GraphSafeRemoveNodes( + Graph* graph, + const std::unordered_set& nodes, + std::unordered_set>* saved_nodes = nullptr); // Some pre-defined patterns those can be reused in multiple passes. // The related Fluid Layer or Op should be one pattern here for better re-usage diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index b8b127201cccde2437a9f98f977f6692cfafe798..73f7e9a098c14779124e7c329109daf50a9a80a6 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -50,7 +50,10 @@ using pybind11::return_value_policy; namespace paddle { namespace pybind { void BindGraph(py::module *m) { - m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes); + m->def("graph_safe_remove_nodes", + [](Graph *graph, const std::unordered_set &nodes) { + return GraphSafeRemoveNodes(graph, nodes); + }); m->def("has_circle", HasCircle); m->def("graph_num", GraphNum); m->def(