提交 6efdea89 编写于 作者: N nhzlx

1. add shuffle_channel_detect

上级 8121b3ec
...@@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base) ...@@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base) pass_library(runtime_context_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(fillconstant_elementwisemul_fuse inference) pass_library(fillconstant_elementwisemul_fuse inference)
pass_library(shuffle_channel_detect_pass inference)
if(ANAKIN_FOUND) if(ANAKIN_FOUND)
pass_library(simplify_anakin_priorbox_detection_out_pass inference) pass_library(simplify_anakin_priorbox_detection_out_pass inference)
......
...@@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input, ...@@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
} }
} }
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
auto reshape1_op =
pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2");
auto reshape1_out = pattern->NewNode(reshape1_out_repr())
->assert_is_op_output("reshape2", "Out")
->assert_is_op_input("transpose2")
->AsIntermediate();
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto transpose_out = pattern->NewNode(transpose_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("reshape2")
->AsIntermediate();
auto reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto reshape2_out = pattern->NewNode(reshape2_out_repr())
->assert_is_op_output("reshape2", "Out")
->AsOutput();
reshape1_op->LinksFrom({reshape1_in});
reshape1_out->LinksFrom({reshape1_op});
transpose_op->LinksFrom({reshape1_out});
transpose_out->LinksFrom({transpose_op});
reshape2_op->LinksFrom({transpose_out});
reshape2_out->LinksFrom({reshape2_op});
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase { ...@@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase {
} }
}; };
struct ShuffleChannelPattern : public PatternBase {
ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "shufflechannel_pattern") {}
void operator()(PDNode* reshape1_in);
PATTERN_DECL_NODE(reshape1_op);
PATTERN_DECL_NODE(reshape1_out);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
...@@ -79,7 +79,11 @@ const std::vector<std::string> kAnakinSubgraphPasses({ ...@@ -79,7 +79,11 @@ const std::vector<std::string> kAnakinSubgraphPasses({
"fc_fuse_pass", // "fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
"fc_gru_fuse_pass", // "fc_gru_fuse_pass", //
"graph_viz_pass", //
"shuffle_channel_detect_pass", //
"graph_viz_pass", //
"anakin_subgraph_pass", // "anakin_subgraph_pass", //
"graph_viz_pass", //
"fc_gru_fuse_pass", // "fc_gru_fuse_pass", //
}); });
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册