未验证 提交 be132719 编写于 作者: P pangyoki 提交者: GitHub

fix RemoveIntermediateOut in fuse_elewise_add_act_pass while converting graph to program (#44593)

* fix RemoveNode in fuse_elewise_add_act_pass

* fix

* change pointer to share_ptr

* fix

* fix

* fix format

* fix

* fix graph_safe_remove_nodes
上级 b20f771f
......@@ -297,7 +297,18 @@ void FuseElewiseAddActPass::RemoveIntermediateOut(Graph *graph) const {
}
}
}
GraphSafeRemoveNodes(graph, need_removed_nodes);
details::RemovedVars *saved_removed_nodes = new details::RemovedVars;
GraphSafeRemoveNodes(graph, need_removed_nodes, saved_removed_nodes);
if (!saved_removed_nodes->empty()) {
// TODO(pangyoki): If kRemovedVars exists, merge saved_removed_nodes into
// RemovedVars.
PADDLE_ENFORCE_EQ(graph->Has(details::kRemovedVars),
false,
platform::errors::PreconditionNotMet(
"Removed nodes are only saved for "
"fuse_elewise_add_act_pass in temporary."));
graph->Set(details::kRemovedVars, saved_removed_nodes);
}
}
void FuseElewiseAddActPass::ReLinkNodes(Graph *graph,
......
......@@ -45,6 +45,8 @@ namespace details {
// This attr is not recommended, because the graph should not dependence
// the program once it is built.
constexpr char kStaleProgramOpDescs[] = "stale_program_op_descs";
constexpr char kRemovedVars[] = "removed_vars";
typedef std::unordered_set<std::shared_ptr<ir::Node>> RemovedVars;
} // namespace details
namespace ir {
......
......@@ -549,6 +549,18 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
}
}
template <class T = Node *>
static void GetGraphVarDesc(const Graph &graph,
const std::unordered_set<T> &nodes,
std::vector<proto::VarDesc> *vars) {
for (T node : nodes) {
if (node->IsVar() && node->Var() &&
node->GetVarNodeBlockId() == graph.GetBlockId()) {
vars->emplace_back(*node->Var()->Proto());
}
}
}
static void GraphToBlock(const Graph &graph,
proto::BlockDesc *block,
const SortKind *sort_kind) {
......@@ -562,11 +574,11 @@ static void GraphToBlock(const Graph &graph,
}
std::vector<proto::VarDesc> vars_in_graph;
for (Node *node : graph.Nodes()) {
if (node->IsVar() && node->Var() &&
node->GetVarNodeBlockId() == graph.GetBlockId()) {
vars_in_graph.emplace_back(*node->Var()->Proto());
}
GetGraphVarDesc<Node *>(graph, graph.Nodes(), &vars_in_graph);
if (graph.Has(details::kRemovedVars)) {
auto &removed_vars = graph.Get<details::RemovedVars>(details::kRemovedVars);
GetGraphVarDesc<std::shared_ptr<ir::Node>>(
graph, removed_vars, &vars_in_graph);
}
// add vars_in_graph to blcok
......
......@@ -771,11 +771,19 @@ bool IsNthOutput(Node *var, Node *op, const std::string &argument, size_t nth) {
return var->Name() == op->Op()->Output(argument)[nth];
}
void GraphSafeRemoveNodes(Graph *graph,
const std::unordered_set<const Node *> &nodes) {
void GraphSafeRemoveNodes(
Graph *graph,
const std::unordered_set<const Node *> &nodes,
std::unordered_set<std::shared_ptr<Node>> *saved_nodes) {
for (auto *node : nodes) {
if (saved_nodes != nullptr) {
// prevent unique_ptr node from being released
saved_nodes->insert(
std::move(graph->RemoveNode(const_cast<Node *>(node))));
} else {
graph->RemoveNode(const_cast<Node *>(node));
}
}
for (auto *node : graph->Nodes()) {
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
......
......@@ -392,8 +392,10 @@ bool HasOutput(Node* op, const std::string& argument);
bool IsNthOutput(Node* var, Node* op, const std::string& argument, size_t nth);
// Graph safely remove some nodes, will automatically clean up the edges.
void GraphSafeRemoveNodes(Graph* graph,
const std::unordered_set<const Node*>& nodes);
void GraphSafeRemoveNodes(
Graph* graph,
const std::unordered_set<const Node*>& nodes,
std::unordered_set<std::shared_ptr<Node>>* saved_nodes = nullptr);
// Some pre-defined patterns those can be reused in multiple passes.
// The related Fluid Layer or Op should be one pattern here for better re-usage
......
......@@ -50,7 +50,10 @@ using pybind11::return_value_policy;
namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
m->def("graph_safe_remove_nodes",
[](Graph *graph, const std::unordered_set<const Node *> &nodes) {
return GraphSafeRemoveNodes(graph, nodes);
});
m->def("has_circle", HasCircle);
m->def("graph_num", GraphNum);
m->def(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册