From 6efdea8997cda4d737777f9be89cc9991120df64 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Mon, 15 Apr 2019 15:13:02 +0000 Subject: [PATCH] 1. add shuffle_channel_detect --- paddle/fluid/framework/ir/CMakeLists.txt | 1 + .../framework/ir/graph_pattern_detector.cc | 31 +++++++++++++++++++ .../framework/ir/graph_pattern_detector.h | 15 +++++++++ .../inference/api/paddle_pass_builder.cc | 4 +++ 4 files changed, 51 insertions(+) diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index ba1d7379c56..a26732926c2 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -70,6 +70,7 @@ pass_library(sync_batch_norm_pass base) pass_library(runtime_context_cache_pass base) pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(fillconstant_elementwisemul_fuse inference) +pass_library(shuffle_channel_detect_pass inference) if(ANAKIN_FOUND) pass_library(simplify_anakin_priorbox_detection_out_pass inference) diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index 77f50e914b6..0dcf064902d 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1706,6 +1706,37 @@ void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input, } } +void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) { + auto reshape1_op = + pattern->NewNode(reshape1_op_repr())->assert_is_op("reshape2"); + + auto reshape1_out = pattern->NewNode(reshape1_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->assert_is_op_input("transpose2") + ->AsIntermediate(); + + auto transpose_op = + pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2"); + + auto transpose_out = pattern->NewNode(transpose_out_repr()) + ->assert_is_op_output("transpose2", "Out") + ->assert_is_op_input("reshape2") + ->AsIntermediate(); + + auto reshape2_op = + pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2"); + auto reshape2_out = pattern->NewNode(reshape2_out_repr()) + ->assert_is_op_output("reshape2", "Out") + ->AsOutput(); + + reshape1_op->LinksFrom({reshape1_in}); + reshape1_out->LinksFrom({reshape1_op}); + transpose_op->LinksFrom({reshape1_out}); + transpose_out->LinksFrom({transpose_op}); + reshape2_op->LinksFrom({transpose_out}); + reshape2_out->LinksFrom({reshape2_op}); +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index 525987e0072..907371b56b0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -892,6 +892,21 @@ struct QuantDequantOpFuse : public PatternBase { } }; +struct ShuffleChannelPattern : public PatternBase { + ShuffleChannelPattern(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "shufflechannel_pattern") {} + + void operator()(PDNode* reshape1_in); + + PATTERN_DECL_NODE(reshape1_op); + PATTERN_DECL_NODE(reshape1_out); + + PATTERN_DECL_NODE(transpose_op); + PATTERN_DECL_NODE(transpose_out); + PATTERN_DECL_NODE(reshape2_op); + PATTERN_DECL_NODE(reshape2_out); +}; + } // namespace patterns // Link two ir::Nodes from each other. diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index fea291c5528..ab347b85885 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -79,7 +79,11 @@ const std::vector kAnakinSubgraphPasses({ "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // "fc_gru_fuse_pass", // + "graph_viz_pass", // + "shuffle_channel_detect_pass", // + "graph_viz_pass", // "anakin_subgraph_pass", // + "graph_viz_pass", // "fc_gru_fuse_pass", // }); -- GitLab