From 3297a5af3c431a36e8432a705164cf8d50386d0e Mon Sep 17 00:00:00 2001 From: zhupengyang <1165938320@qq.com> Date: Wed, 18 Mar 2020 14:50:07 +0800 Subject: [PATCH] Revert "enhance shuffle-channel fuse pass (#3208)" (#3214) This reverts commit 670fcc3c796605c262034af0216e7efa148975a9. --- .../mir/fusion/shuffle_channel_fuse_pass.cc | 14 +- lite/core/mir/fusion/shuffle_channel_fuser.cc | 163 ++++++------------ lite/core/mir/fusion/shuffle_channel_fuser.h | 13 +- 3 files changed, 60 insertions(+), 130 deletions(-) diff --git a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc index 579c7559d1..2c289da82c 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc @@ -23,15 +23,11 @@ namespace lite { namespace mir { void ShuffleChannelFusePass::Apply(const std::unique_ptr& graph) { - for (std::string reshape_type : {"reshape", "reshape2"}) { - for (std::string transpose_type : {"transpose", "transpose2"}) { - for (std::string sub_structure : {"r_t_r", "s_t_r"}) { - fusion::ShuffleChannelFuser fuser( - reshape_type, transpose_type, sub_structure); - fuser(graph.get()); - } - } - } + fusion::ShuffleChannelFuser fuser("reshape", "transpose"); + fuser(graph.get()); + + fusion::ShuffleChannelFuser fuser2("reshape2", "transpose2"); + fuser2(graph.get()); } } // namespace mir diff --git a/lite/core/mir/fusion/shuffle_channel_fuser.cc b/lite/core/mir/fusion/shuffle_channel_fuser.cc index 7531d634eb..f0087f8991 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuser.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuser.cc @@ -22,107 +22,56 @@ namespace mir { namespace fusion { void ShuffleChannelFuser::BuildPattern() { - if (sub_structure_ == "r_t_r") { - // create nodes. - auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); - auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); - auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); - auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); - - PMNode* xshape1 = nullptr; - PMNode* xshape2 = nullptr; - PMNode* xshape3 = nullptr; - if (reshape_type_ == "reshape2") { - xshape1 = - VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); - xshape3 = - VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); - } - if (transpose_type_ == "transpose2") { - xshape2 = - VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); - } - - auto* reshape1 = OpNode("reshape1", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] > 0; - }); - auto* transpose = - OpNode("transpose_op", transpose_type_) - ->assert_op_attr_satisfied>( - "axis", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; - }); - auto* reshape2 = OpNode("reshape2", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 4; - }); - - // create topology. - *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; - if (xshape1) *reshape1 >> *xshape1; - if (xshape2) *transpose >> *xshape2; - if (xshape3) *reshape2 >> *xshape3; - - // Some op specialities. - y1->AsIntermediate(); - y2->AsIntermediate(); - if (xshape1) xshape1->AsIntermediate(); - if (xshape2) xshape2->AsIntermediate(); - if (xshape3) xshape3->AsIntermediate(); - reshape1->AsIntermediate(); - transpose->AsIntermediate(); - reshape2->AsIntermediate(); + // create nodes. + auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); + auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); + auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); + auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); + + PMNode* xshape1 = nullptr; + PMNode* xshape2 = nullptr; + PMNode* xshape3 = nullptr; + if (reshape_type_ == "reshape2") { + xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); + xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); } - - if (sub_structure_ == "s_t_r") { - // create nodes. - auto* x1 = VarNode("x1") - ->assert_is_op_input(transpose_type_, "X") - ->assert_is_op_output("stack", "Y"); - auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out"); - auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); - - PMNode* xshape1 = nullptr; - PMNode* xshape2 = nullptr; - if (transpose_type_ == "transpose2") { - xshape1 = - VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); - } - if (reshape_type_ == "reshape2") { - xshape2 = - VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); - } - - auto* stack = OpNode("stack_op", "stack") - ->assert_op_attr_satisfied( - "axis", [](const int& attr) { return attr == 1; }); - auto* transpose = - OpNode("transpose_op", transpose_type_) - ->assert_op_attr_satisfied>( - "axis", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; - }); - auto* reshape = OpNode("reshape_op", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 4; - }); - - // create topology. - *stack >> *x1 >> *transpose >> *y1 >> *reshape >> *out; - if (xshape1) *transpose >> *xshape1; - if (xshape2) *reshape >> *xshape2; - - // Some op specialities. - y1->AsIntermediate(); - if (xshape1) xshape1->AsIntermediate(); - if (xshape2) xshape2->AsIntermediate(); - transpose->AsIntermediate(); - reshape->AsIntermediate(); + if (transpose_type_ == "transpose2") { + xshape2 = + VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); } + + auto* reshape1 = OpNode("reshape1", reshape_type_) + ->assert_op_attr_satisfied>( + "shape", [](const std::vector& attr) { + return attr.size() >= 5 && attr[1] > 0; + }); + auto* transpose = + OpNode("transpose_op", transpose_type_) + ->assert_op_attr_satisfied>( + "axis", [](const std::vector& attr) { + return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; + }); + auto* reshape2 = OpNode("reshape2", reshape_type_) + ->assert_op_attr_satisfied>( + "shape", [](const std::vector& attr) { + return attr.size() >= 4; + }); + + // create topology. + *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; + if (xshape1) *reshape1 >> *xshape1; + if (xshape2) *transpose >> *xshape2; + if (xshape3) *reshape2 >> *xshape3; + + // Some op specialities. + y1->AsIntermediate(); + y2->AsIntermediate(); + if (xshape1) xshape1->AsIntermediate(); + if (xshape2) xshape2->AsIntermediate(); + if (xshape3) xshape3->AsIntermediate(); + reshape1->AsIntermediate(); + transpose->AsIntermediate(); + reshape2->AsIntermediate(); } void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph, @@ -146,17 +95,11 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetType("shuffle_channel"); op_desc.SetInput("X", {matched.at("x1")->arg()->name}); op_desc.SetOutput("Out", {matched.at("out")->arg()->name}); - int group = 1; - if (sub_structure_ == "r_t_r") { - group = matched.at("reshape1") - ->stmt() - ->op_info() - ->GetAttr>("shape")[1]; - } - if (sub_structure_ == "s_t_r") { - group = matched.at("stack_op")->inlinks.size(); - } - op_desc.SetAttr("group", group); + op_desc.SetAttr("group", + matched.at("reshape1") + ->stmt() + ->op_info() + ->GetAttr>("shape")[1]); return op_desc; } diff --git a/lite/core/mir/fusion/shuffle_channel_fuser.h b/lite/core/mir/fusion/shuffle_channel_fuser.h index 46d8b1c47f..4fb99ab5c8 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuser.h +++ b/lite/core/mir/fusion/shuffle_channel_fuser.h @@ -26,25 +26,16 @@ namespace fusion { class ShuffleChannelFuser : public FuseBase { public: explicit ShuffleChannelFuser(const std::string& reshape_type, - const std::string& transpose_type, - const std::string& sub_structure) - : reshape_type_(reshape_type), - transpose_type_(transpose_type), - sub_structure_(sub_structure) {} + const std::string& transpose_type) + : reshape_type_(reshape_type), transpose_type_(transpose_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; - // reshape or reshape2 std::string reshape_type_; - // transpose or transpose2 std::string transpose_type_; - // r_t_r or t_r - // r_t_r: reshape + transpose + reshape - // s_t_r: stack + transpose + reshape - std::string sub_structure_; }; } // namespace fusion -- GitLab