diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ba1d7379c56d953a0f37d03deed6c47e46cbf129..a26732926c2c6e376079c610078c80c6a8afd452 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base) pass_library(runtime_context_cache_pass base) pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(fillconstant_elementwisemul_fuse inference) +pass_library(shuffle_channel_detect_pass inference) if(ANAKIN_FOUND) pass_library(simplify_anakin_priorbox_detection_out_pass inference) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 77f50e914b668ebfeb2fcaf5de8f91a74f0c0d3b..0dcf064902d1c1c6cb034421cedea0387b6e0505 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input, } } +void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) { + auto reshape1_op = + pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2"); + + auto reshape1_out = pattern->NewNode(reshape1_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2") + ->AsIntermediate(); + + auto transpose_op = + pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); + + auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("reshape2") + ->AsIntermediate(); + + auto reshape2_op = + pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2"); + auto reshape2_out = pattern->NewNode(reshape2_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->AsOutput(); + + reshape1_op->LinksFrom({reshape1_in}); + reshape1_out->LinksFrom({reshape1_op}); + transpose_op->LinksFrom({reshape1_out}); + transpose_out->LinksFrom({transpose_op}); + reshape2_op->LinksFrom({transpose_out}); + reshape2_out->LinksFrom({reshape2_op}); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 525987e0072cb05ad3df4d09a17ac172e48ce133..907371b56b06dcd66297adedea6c17b61d9b5e38 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase { } }; +struct ShuffleChannelPattern : public PatternBase { + ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "shufflechannel_pattern") {} + + void operator()(PDNode* reshape1_in); + + PATTERN_DECL_NODE(reshape1_op); + PATTERN_DECL_NODE(reshape1_out); + + PATTERN_DECL_NODE(transpose_op); + PATTERN_DECL_NODE(transpose_out); + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(reshape2_out); +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index fea291c5528a11fd18b1069a5d57e456c8cc84fc..ab347b85885fe3bc6a46c60e572e3a03185b5f44 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -79,7 +79,11 @@ const std::vector kAnakinSubgraphPasses({ "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // "fc_gru_fuse_pass", // + "graph_viz_pass", // + "shuffle_channel_detect_pass", // + "graph_viz_pass", // "anakin_subgraph_pass", // + "graph_viz_pass", // "fc_gru_fuse_pass", // });