未验证 提交 54f07019 编写于 作者: 王明冬 提交者: GitHub

fix the pass compat check position error, test=develop (#35272)

上级 ef536250
......@@ -264,10 +264,6 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse squeeze2+matmul to mul";
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
......@@ -299,6 +295,10 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
next_ops[0]->Name() == "elementwise_add";
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {squeeze2_in_x->Name()});
......@@ -403,10 +403,6 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
......@@ -441,6 +437,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
next_ops[0]->Name() == "elementwise_add";
if (flag) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {reshape2_in_x->Name()});
......@@ -483,11 +483,6 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
......@@ -527,6 +522,10 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
next_ops[0]->Name() == "elementwise_add";
if (pattern_found) {
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {flatten2_in_x->Name()});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册