未验证 提交 2b546613 编写于 作者: Z zhupengyang 提交者: GitHub

refine delete_repeated_ops_pass (#54122)

* Not delete slice op if out op has shared var node
上级 814f1d65
...@@ -32,6 +32,19 @@ class Scope; ...@@ -32,6 +32,19 @@ class Scope;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
bool HasOutVarName(Node* op_node, std::string name) {
auto* op_desc = op_node->Op();
auto outputs = op_desc->Outputs();
for (auto iter : outputs) {
auto out_names = iter.second;
if (std::count(out_names.begin(), out_names.end(), name) > 0) {
return true;
}
}
return false;
}
namespace patterns { namespace patterns {
struct VarWithRepeatedOpsPattern : public PatternBase { struct VarWithRepeatedOpsPattern : public PatternBase {
...@@ -106,18 +119,22 @@ int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const { ...@@ -106,18 +119,22 @@ int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const {
VLOG(4) << "handle DeleteShapePass"; VLOG(4) << "handle DeleteShapePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern); GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::vector<std::string> invalid_shape_out_ops{"while",
"conditional_block"};
std::vector<Node*> shapes; std::vector<Node*> shapes;
for (auto* next_op : in_var->outputs) { for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "shape") continue; if (next_op->Name() != "shape") continue;
bool shape_out_has_control_flow_ops = false; bool shape_out_op_is_invalid = false;
for (auto* shape_out_op : next_op->outputs[0]->outputs) { for (auto* shape_out_op : next_op->outputs[0]->outputs) {
if (shape_out_op->Name() == "while" || if (std::count(invalid_shape_out_ops.begin(),
shape_out_op->Name() == "conditional_block") { invalid_shape_out_ops.end(),
shape_out_has_control_flow_ops = true; shape_out_op->Name()) > 0 ||
HasOutVarName(shape_out_op, next_op->outputs[0]->Name())) {
shape_out_op_is_invalid = true;
break; break;
} }
} }
if (!shape_out_has_control_flow_ops) { if (!shape_out_op_is_invalid) {
shapes.push_back(next_op); shapes.push_back(next_op);
} }
} }
...@@ -183,19 +200,23 @@ int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const { ...@@ -183,19 +200,23 @@ int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const {
VLOG(4) << "handle DeleteSlicePass"; VLOG(4) << "handle DeleteSlicePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern); GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::vector<std::string> invalid_slice_out_ops{"while",
"conditional_block"};
std::map<std::string, std::vector<Node*>> slice_ops; std::map<std::string, std::vector<Node*>> slice_ops;
for (auto* next_op : in_var->outputs) { for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "slice") continue; if (next_op->Name() != "slice") continue;
auto* slice = next_op; auto* slice = next_op;
bool slice_out_has_control_flow_ops = false; bool slice_out_op_is_invalid = false;
for (auto* slice_out_op : slice->outputs[0]->outputs) { for (auto* slice_out_op : slice->outputs[0]->outputs) {
if (slice_out_op->Name() == "while" || if (std::count(invalid_slice_out_ops.begin(),
slice_out_op->Name() == "conditional_block") { invalid_slice_out_ops.end(),
slice_out_has_control_flow_ops = true; slice_out_op->Name()) > 0 ||
HasOutVarName(slice_out_op, slice->outputs[0]->Name())) {
slice_out_op_is_invalid = true;
break; break;
} }
} }
if (slice_out_has_control_flow_ops) continue; if (slice_out_op_is_invalid) continue;
auto attr_key = GenSliceAttrKey(slice->Op()); auto attr_key = GenSliceAttrKey(slice->Op());
slice_ops[attr_key].push_back(slice); slice_ops[attr_key].push_back(slice);
} }
...@@ -217,7 +238,8 @@ int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const { ...@@ -217,7 +238,8 @@ int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const {
auto* cur_slice_out = cur_slice->outputs[0]; auto* cur_slice_out = cur_slice->outputs[0];
auto cur_slice_out_name = cur_slice_out->Name(); auto cur_slice_out_name = cur_slice_out->Name();
for (auto* slice_out_op : cur_slice_out->outputs) { for (auto* slice_out_op : cur_slice_out->outputs) {
slice_out_op->Op()->Rename(cur_slice_out_name, first_slice_out_name); slice_out_op->Op()->RenameInput(cur_slice_out_name,
first_slice_out_name);
IR_NODE_LINK_TO(first_slice_out, slice_out_op); IR_NODE_LINK_TO(first_slice_out, slice_out_op);
} }
delete_nodes.insert(cur_slice); delete_nodes.insert(cur_slice);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册