diff --git a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc index 579c7559d13ee5a4794ce1be5e5faac50d3cc0e5..2c289da82c69e9abac8cbc32a2efab47ebc05336 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuse_pass.cc @@ -23,15 +23,11 @@ namespace lite { namespace mir { void ShuffleChannelFusePass::Apply(const std::unique_ptr& graph) { - for (std::string reshape_type : {"reshape", "reshape2"}) { - for (std::string transpose_type : {"transpose", "transpose2"}) { - for (std::string sub_structure : {"r_t_r", "s_t_r"}) { - fusion::ShuffleChannelFuser fuser( - reshape_type, transpose_type, sub_structure); - fuser(graph.get()); - } - } - } + fusion::ShuffleChannelFuser fuser("reshape", "transpose"); + fuser(graph.get()); + + fusion::ShuffleChannelFuser fuser2("reshape2", "transpose2"); + fuser2(graph.get()); } } // namespace mir diff --git a/lite/core/mir/fusion/shuffle_channel_fuser.cc b/lite/core/mir/fusion/shuffle_channel_fuser.cc index 7531d634eb97108ae9b3a48b6d00bc00e22f5955..f0087f8991b6b4457da29db0feac30c6bf9e722e 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuser.cc +++ b/lite/core/mir/fusion/shuffle_channel_fuser.cc @@ -22,107 +22,56 @@ namespace mir { namespace fusion { void ShuffleChannelFuser::BuildPattern() { - if (sub_structure_ == "r_t_r") { - // create nodes. - auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); - auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); - auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); - auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); - - PMNode* xshape1 = nullptr; - PMNode* xshape2 = nullptr; - PMNode* xshape3 = nullptr; - if (reshape_type_ == "reshape2") { - xshape1 = - VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); - xshape3 = - VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); - } - if (transpose_type_ == "transpose2") { - xshape2 = - VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); - } - - auto* reshape1 = OpNode("reshape1", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] > 0; - }); - auto* transpose = - OpNode("transpose_op", transpose_type_) - ->assert_op_attr_satisfied>( - "axis", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; - }); - auto* reshape2 = OpNode("reshape2", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 4; - }); - - // create topology. - *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; - if (xshape1) *reshape1 >> *xshape1; - if (xshape2) *transpose >> *xshape2; - if (xshape3) *reshape2 >> *xshape3; - - // Some op specialities. - y1->AsIntermediate(); - y2->AsIntermediate(); - if (xshape1) xshape1->AsIntermediate(); - if (xshape2) xshape2->AsIntermediate(); - if (xshape3) xshape3->AsIntermediate(); - reshape1->AsIntermediate(); - transpose->AsIntermediate(); - reshape2->AsIntermediate(); + // create nodes. + auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); + auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); + auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); + auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); + + PMNode* xshape1 = nullptr; + PMNode* xshape2 = nullptr; + PMNode* xshape3 = nullptr; + if (reshape_type_ == "reshape2") { + xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); + xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); } - - if (sub_structure_ == "s_t_r") { - // create nodes. - auto* x1 = VarNode("x1") - ->assert_is_op_input(transpose_type_, "X") - ->assert_is_op_output("stack", "Y"); - auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out"); - auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); - - PMNode* xshape1 = nullptr; - PMNode* xshape2 = nullptr; - if (transpose_type_ == "transpose2") { - xshape1 = - VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); - } - if (reshape_type_ == "reshape2") { - xshape2 = - VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); - } - - auto* stack = OpNode("stack_op", "stack") - ->assert_op_attr_satisfied( - "axis", [](const int& attr) { return attr == 1; }); - auto* transpose = - OpNode("transpose_op", transpose_type_) - ->assert_op_attr_satisfied>( - "axis", [](const std::vector& attr) { - return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; - }); - auto* reshape = OpNode("reshape_op", reshape_type_) - ->assert_op_attr_satisfied>( - "shape", [](const std::vector& attr) { - return attr.size() >= 4; - }); - - // create topology. - *stack >> *x1 >> *transpose >> *y1 >> *reshape >> *out; - if (xshape1) *transpose >> *xshape1; - if (xshape2) *reshape >> *xshape2; - - // Some op specialities. - y1->AsIntermediate(); - if (xshape1) xshape1->AsIntermediate(); - if (xshape2) xshape2->AsIntermediate(); - transpose->AsIntermediate(); - reshape->AsIntermediate(); + if (transpose_type_ == "transpose2") { + xshape2 = + VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); } + + auto* reshape1 = OpNode("reshape1", reshape_type_) + ->assert_op_attr_satisfied>( + "shape", [](const std::vector& attr) { + return attr.size() >= 5 && attr[1] > 0; + }); + auto* transpose = + OpNode("transpose_op", transpose_type_) + ->assert_op_attr_satisfied>( + "axis", [](const std::vector& attr) { + return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; + }); + auto* reshape2 = OpNode("reshape2", reshape_type_) + ->assert_op_attr_satisfied>( + "shape", [](const std::vector& attr) { + return attr.size() >= 4; + }); + + // create topology. + *x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out; + if (xshape1) *reshape1 >> *xshape1; + if (xshape2) *transpose >> *xshape2; + if (xshape3) *reshape2 >> *xshape3; + + // Some op specialities. + y1->AsIntermediate(); + y2->AsIntermediate(); + if (xshape1) xshape1->AsIntermediate(); + if (xshape2) xshape2->AsIntermediate(); + if (xshape3) xshape3->AsIntermediate(); + reshape1->AsIntermediate(); + transpose->AsIntermediate(); + reshape2->AsIntermediate(); } void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph, @@ -146,17 +95,11 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) { op_desc.SetType("shuffle_channel"); op_desc.SetInput("X", {matched.at("x1")->arg()->name}); op_desc.SetOutput("Out", {matched.at("out")->arg()->name}); - int group = 1; - if (sub_structure_ == "r_t_r") { - group = matched.at("reshape1") - ->stmt() - ->op_info() - ->GetAttr>("shape")[1]; - } - if (sub_structure_ == "s_t_r") { - group = matched.at("stack_op")->inlinks.size(); - } - op_desc.SetAttr("group", group); + op_desc.SetAttr("group", + matched.at("reshape1") + ->stmt() + ->op_info() + ->GetAttr>("shape")[1]); return op_desc; } diff --git a/lite/core/mir/fusion/shuffle_channel_fuser.h b/lite/core/mir/fusion/shuffle_channel_fuser.h index 46d8b1c47f74a5a8418f045ec625c07ac26c1fb1..4fb99ab5c85a9d166e65eca1050f800e8c9b1795 100644 --- a/lite/core/mir/fusion/shuffle_channel_fuser.h +++ b/lite/core/mir/fusion/shuffle_channel_fuser.h @@ -26,25 +26,16 @@ namespace fusion { class ShuffleChannelFuser : public FuseBase { public: explicit ShuffleChannelFuser(const std::string& reshape_type, - const std::string& transpose_type, - const std::string& sub_structure) - : reshape_type_(reshape_type), - transpose_type_(transpose_type), - sub_structure_(sub_structure) {} + const std::string& transpose_type) + : reshape_type_(reshape_type), transpose_type_(transpose_type) {} void BuildPattern() override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; private: cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; - // reshape or reshape2 std::string reshape_type_; - // transpose or transpose2 std::string transpose_type_; - // r_t_r or t_r - // r_t_r: reshape + transpose + reshape - // s_t_r: stack + transpose + reshape - std::string sub_structure_; }; } // namespace fusion