From 1af0e6c38e7a7886eddfbd21b579fc1bcc01157a Mon Sep 17 00:00:00 2001 From: zhupengyang <1165938320@qq.com> Date: Tue, 17 Mar 2020 21:50:14 +0800 Subject: [PATCH] enhance shuffle-channel fuse pass (#3208) --- .../mir/fusion/shuffle_channel_fuse_pass.cc | 14 +- lite/core/mir/fusion/shuffle_channel_fuser.cc | 163 ++++++++++++------ lite/core/mir/fusion/shuffle_channel_fuser.h | 13 +- 3 files changed, 130 insertions(+), 60 deletions(-) diff --git a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc index 2c289da82c..579c7559d1 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc @@ -23,11 +23,15 @@ namespace lite { namespace mir { void ShuffleChannelFusePass::Apply(const std::unique_ptr& graph) { - fusion::ShuffleChannelFuser fuser("reshape", "transpose"); - fuser(graph.get()); - - fusion::ShuffleChannelFuser fuser2("reshape2", "transpose2"); - fuser2(graph.get()); + for (std::string reshape_type : {"reshape", "reshape2"}) { + for (std::string transpose_type : {"transpose", "transpose2"}) { + for (std::string sub_structure : {"r_t_r", "s_t_r"}) { + fusion::ShuffleChannelFuser fuser( + reshape_type, transpose_type, sub_structure); + fuser(graph.get()); + } + } + } } } // namespace mir diff --git a/lite/core/mir/fusion/shuffle_channel_fuser.cc b/lite/core/mir/fusion/shuffle_channel_fuser.cc index f0087f8991..7531d634eb 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuser.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuser.cc @@ -22,56 +22,107 @@ namespace mir { namespace fusion { void ShuffleChannelFuser::BuildPattern() { - // create nodes. - auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); - auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); - auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); - auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); - - PMNode* xshape1 = nullptr; - PMNode* xshape2 = nullptr; - PMNode* xshape3 = nullptr; - if (reshape_type_ == "reshape2") { - xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); - xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); - } - if (transpose_type_ == "transpose2") { - xshape2 = - VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); + if (sub_structure_ == "r_t_r") { + // create nodes. + auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); + auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); + auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); + auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); + + PMNode* xshape1 = nullptr; + PMNode* xshape2 = nullptr; + PMNode* xshape3 = nullptr; + if (reshape_type_ == "reshape2") { + xshape1 = + VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); + xshape3 = + VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); + } + if (transpose_type_ == "transpose2") { + xshape2 = + VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); + } + + auto* reshape1 = OpNode("reshape1", reshape_type_) + ->assert_op_attr_satisfied>( + "shape", [](const std::vector& attr) { + return attr.size() >= 5 && attr[1] > 0; + }); + auto* transpose = + OpNode("transpose_op", transpose_type_) + ->assert_op_attr_satisfied>( + "axis", [](const std::vector& attr) { + return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; + }); + auto* reshape2 = OpNode("reshape2", reshape_type_) + ->assert_op_attr_satisfied>( + "shape", [](const std::vector& attr) { + return attr.size() >= 4; + }); + + // create topology. + *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; + if (xshape1) *reshape1 >> *xshape1; + if (xshape2) *transpose >> *xshape2; + if (xshape3) *reshape2 >> *xshape3; + + // Some op specialities. + y1->AsIntermediate(); + y2->AsIntermediate(); + if (xshape1) xshape1->AsIntermediate(); + if (xshape2) xshape2->AsIntermediate(); + if (xshape3) xshape3->AsIntermediate(); + reshape1->AsIntermediate(); + transpose->AsIntermediate(); + reshape2->AsIntermediate(); } - auto* reshape1 = OpNode("reshape1", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] > 0; - }); - auto* transpose = - OpNode("transpose_op", transpose_type_) - ->assert_op_attr_satisfied>( - "axis", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; - }); - auto* reshape2 = OpNode("reshape2", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 4; - }); - - // create topology. - *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; - if (xshape1) *reshape1 >> *xshape1; - if (xshape2) *transpose >> *xshape2; - if (xshape3) *reshape2 >> *xshape3; - - // Some op specialities. - y1->AsIntermediate(); - y2->AsIntermediate(); - if (xshape1) xshape1->AsIntermediate(); - if (xshape2) xshape2->AsIntermediate(); - if (xshape3) xshape3->AsIntermediate(); - reshape1->AsIntermediate(); - transpose->AsIntermediate(); - reshape2->AsIntermediate(); + if (sub_structure_ == "s_t_r") { + // create nodes. + auto* x1 = VarNode("x1") + ->assert_is_op_input(transpose_type_, "X") + ->assert_is_op_output("stack", "Y"); + auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out"); + auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); + + PMNode* xshape1 = nullptr; + PMNode* xshape2 = nullptr; + if (transpose_type_ == "transpose2") { + xshape1 = + VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); + } + if (reshape_type_ == "reshape2") { + xshape2 = + VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); + } + + auto* stack = OpNode("stack_op", "stack") + ->assert_op_attr_satisfied( + "axis", [](const int& attr) { return attr == 1; }); + auto* transpose = + OpNode("transpose_op", transpose_type_) + ->assert_op_attr_satisfied>( + "axis", [](const std::vector& attr) { + return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; + }); + auto* reshape = OpNode("reshape_op", reshape_type_) + ->assert_op_attr_satisfied>( + "shape", [](const std::vector& attr) { + return attr.size() >= 4; + }); + + // create topology. + *stack >> *x1 >> *transpose >> *y1 >> *reshape >> *out; + if (xshape1) *transpose >> *xshape1; + if (xshape2) *reshape >> *xshape2; + + // Some op specialities. + y1->AsIntermediate(); + if (xshape1) xshape1->AsIntermediate(); + if (xshape2) xshape2->AsIntermediate(); + transpose->AsIntermediate(); + reshape->AsIntermediate(); + } } void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph, @@ -95,11 +146,17 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetType("shuffle_channel"); op_desc.SetInput("X", {matched.at("x1")->arg()->name}); op_desc.SetOutput("Out", {matched.at("out")->arg()->name}); - op_desc.SetAttr("group", - matched.at("reshape1") - ->stmt() - ->op_info() - ->GetAttr>("shape")[1]); + int group = 1; + if (sub_structure_ == "r_t_r") { + group = matched.at("reshape1") + ->stmt() + ->op_info() + ->GetAttr>("shape")[1]; + } + if (sub_structure_ == "s_t_r") { + group = matched.at("stack_op")->inlinks.size(); + } + op_desc.SetAttr("group", group); return op_desc; } diff --git a/lite/core/mir/fusion/shuffle_channel_fuser.h b/lite/core/mir/fusion/shuffle_channel_fuser.h index 4fb99ab5c8..46d8b1c47f 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuser.h +++ b/lite/core/mir/fusion/shuffle_channel_fuser.h @@ -26,16 +26,25 @@ namespace fusion { class ShuffleChannelFuser : public FuseBase { public: explicit ShuffleChannelFuser(const std::string& reshape_type, - const std::string& transpose_type) - : reshape_type_(reshape_type), transpose_type_(transpose_type) {} + const std::string& transpose_type, + const std::string& sub_structure) + : reshape_type_(reshape_type), + transpose_type_(transpose_type), + sub_structure_(sub_structure) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; + // reshape or reshape2 std::string reshape_type_; + // transpose or transpose2 std::string transpose_type_; + // r_t_r or t_r + // r_t_r: reshape + transpose + reshape + // s_t_r: stack + transpose + reshape + std::string sub_structure_; }; } // namespace fusion -- GitLab