未验证 提交 2b546613 编写于 作者: Z zhupengyang 提交者: GitHub

refine delete_repeated_ops_pass (#54122)

* Not delete slice op if out op has shared var node
上级 814f1d65
......@@ -32,6 +32,19 @@ class Scope;
namespace paddle {
namespace framework {
namespace ir {
bool HasOutVarName(Node* op_node, std::string name) {
auto* op_desc = op_node->Op();
auto outputs = op_desc->Outputs();
for (auto iter : outputs) {
auto out_names = iter.second;
if (std::count(out_names.begin(), out_names.end(), name) > 0) {
return true;
}
}
return false;
}
namespace patterns {
struct VarWithRepeatedOpsPattern : public PatternBase {
......@@ -106,18 +119,22 @@ int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const {
VLOG(4) << "handle DeleteShapePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::vector<std::string> invalid_shape_out_ops{"while",
"conditional_block"};
std::vector<Node*> shapes;
for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "shape") continue;
bool shape_out_has_control_flow_ops = false;
bool shape_out_op_is_invalid = false;
for (auto* shape_out_op : next_op->outputs[0]->outputs) {
if (shape_out_op->Name() == "while" ||
shape_out_op->Name() == "conditional_block") {
shape_out_has_control_flow_ops = true;
if (std::count(invalid_shape_out_ops.begin(),
invalid_shape_out_ops.end(),
shape_out_op->Name()) > 0 ||
HasOutVarName(shape_out_op, next_op->outputs[0]->Name())) {
shape_out_op_is_invalid = true;
break;
}
}
if (!shape_out_has_control_flow_ops) {
if (!shape_out_op_is_invalid) {
shapes.push_back(next_op);
}
}
......@@ -183,19 +200,23 @@ int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const {
VLOG(4) << "handle DeleteSlicePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::vector<std::string> invalid_slice_out_ops{"while",
"conditional_block"};
std::map<std::string, std::vector<Node*>> slice_ops;
for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "slice") continue;
auto* slice = next_op;
bool slice_out_has_control_flow_ops = false;
bool slice_out_op_is_invalid = false;
for (auto* slice_out_op : slice->outputs[0]->outputs) {
if (slice_out_op->Name() == "while" ||
slice_out_op->Name() == "conditional_block") {
slice_out_has_control_flow_ops = true;
if (std::count(invalid_slice_out_ops.begin(),
invalid_slice_out_ops.end(),
slice_out_op->Name()) > 0 ||
HasOutVarName(slice_out_op, slice->outputs[0]->Name())) {
slice_out_op_is_invalid = true;
break;
}
}
if (slice_out_has_control_flow_ops) continue;
if (slice_out_op_is_invalid) continue;
auto attr_key = GenSliceAttrKey(slice->Op());
slice_ops[attr_key].push_back(slice);
}
......@@ -217,7 +238,8 @@ int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const {
auto* cur_slice_out = cur_slice->outputs[0];
auto cur_slice_out_name = cur_slice_out->Name();
for (auto* slice_out_op : cur_slice_out->outputs) {
slice_out_op->Op()->Rename(cur_slice_out_name, first_slice_out_name);
slice_out_op->Op()->RenameInput(cur_slice_out_name,
first_slice_out_name);
IR_NODE_LINK_TO(first_slice_out, slice_out_op);
}
delete_nodes.insert(cur_slice);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册