提交 1af0e6c3 编写于 作者: Z zhupengyang 提交者: GitHub

enhance shuffle-channel fuse pass (#3208)

上级 fcc6b2da
...@@ -23,11 +23,15 @@ namespace lite { ...@@ -23,11 +23,15 @@ namespace lite {
namespace mir { namespace mir {
void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ShuffleChannelFuser fuser("reshape", "transpose"); for (std::string reshape_type : {"reshape", "reshape2"}) {
fuser(graph.get()); for (std::string transpose_type : {"transpose", "transpose2"}) {
for (std::string sub_structure : {"r_t_r", "s_t_r"}) {
fusion::ShuffleChannelFuser fuser2("reshape2", "transpose2"); fusion::ShuffleChannelFuser fuser(
fuser2(graph.get()); reshape_type, transpose_type, sub_structure);
fuser(graph.get());
}
}
}
} }
} // namespace mir } // namespace mir
......
...@@ -22,56 +22,107 @@ namespace mir { ...@@ -22,56 +22,107 @@ namespace mir {
namespace fusion { namespace fusion {
void ShuffleChannelFuser::BuildPattern() { void ShuffleChannelFuser::BuildPattern() {
// create nodes. if (sub_structure_ == "r_t_r") {
auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); // create nodes.
auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X");
auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out"); auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out"); auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
PMNode* xshape1 = nullptr;
PMNode* xshape2 = nullptr; PMNode* xshape1 = nullptr;
PMNode* xshape3 = nullptr; PMNode* xshape2 = nullptr;
if (reshape_type_ == "reshape2") { PMNode* xshape3 = nullptr;
xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); if (reshape_type_ == "reshape2") {
xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); xshape1 =
} VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
if (transpose_type_ == "transpose2") { xshape3 =
xshape2 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape");
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape"); }
if (transpose_type_ == "transpose2") {
xshape2 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
}
auto* reshape1 = OpNode("reshape1", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] > 0;
});
auto* transpose =
OpNode("transpose_op", transpose_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"axis", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1;
});
auto* reshape2 = OpNode("reshape2", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 4;
});
// create topology.
*x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out;
if (xshape1) *reshape1 >> *xshape1;
if (xshape2) *transpose >> *xshape2;
if (xshape3) *reshape2 >> *xshape3;
// Some op specialities.
y1->AsIntermediate();
y2->AsIntermediate();
if (xshape1) xshape1->AsIntermediate();
if (xshape2) xshape2->AsIntermediate();
if (xshape3) xshape3->AsIntermediate();
reshape1->AsIntermediate();
transpose->AsIntermediate();
reshape2->AsIntermediate();
} }
auto* reshape1 = OpNode("reshape1", reshape_type_) if (sub_structure_ == "s_t_r") {
->assert_op_attr_satisfied<std::vector<int>>( // create nodes.
"shape", [](const std::vector<int>& attr) { auto* x1 = VarNode("x1")
return attr.size() >= 5 && attr[1] > 0; ->assert_is_op_input(transpose_type_, "X")
}); ->assert_is_op_output("stack", "Y");
auto* transpose = auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out");
OpNode("transpose_op", transpose_type_) auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
->assert_op_attr_satisfied<std::vector<int>>(
"axis", [](const std::vector<int>& attr) { PMNode* xshape1 = nullptr;
return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1; PMNode* xshape2 = nullptr;
}); if (transpose_type_ == "transpose2") {
auto* reshape2 = OpNode("reshape2", reshape_type_) xshape1 =
->assert_op_attr_satisfied<std::vector<int>>( VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
"shape", [](const std::vector<int>& attr) { }
return attr.size() >= 4; if (reshape_type_ == "reshape2") {
}); xshape2 =
VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
// create topology. }
*x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out;
if (xshape1) *reshape1 >> *xshape1; auto* stack = OpNode("stack_op", "stack")
if (xshape2) *transpose >> *xshape2; ->assert_op_attr_satisfied<int>(
if (xshape3) *reshape2 >> *xshape3; "axis", [](const int& attr) { return attr == 1; });
auto* transpose =
// Some op specialities. OpNode("transpose_op", transpose_type_)
y1->AsIntermediate(); ->assert_op_attr_satisfied<std::vector<int>>(
y2->AsIntermediate(); "axis", [](const std::vector<int>& attr) {
if (xshape1) xshape1->AsIntermediate(); return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1;
if (xshape2) xshape2->AsIntermediate(); });
if (xshape3) xshape3->AsIntermediate(); auto* reshape = OpNode("reshape_op", reshape_type_)
reshape1->AsIntermediate(); ->assert_op_attr_satisfied<std::vector<int>>(
transpose->AsIntermediate(); "shape", [](const std::vector<int>& attr) {
reshape2->AsIntermediate(); return attr.size() >= 4;
});
// create topology.
*stack >> *x1 >> *transpose >> *y1 >> *reshape >> *out;
if (xshape1) *transpose >> *xshape1;
if (xshape2) *reshape >> *xshape2;
// Some op specialities.
y1->AsIntermediate();
if (xshape1) xshape1->AsIntermediate();
if (xshape2) xshape2->AsIntermediate();
transpose->AsIntermediate();
reshape->AsIntermediate();
}
} }
void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph, void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph,
...@@ -95,11 +146,17 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -95,11 +146,17 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetType("shuffle_channel"); op_desc.SetType("shuffle_channel");
op_desc.SetInput("X", {matched.at("x1")->arg()->name}); op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name}); op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("group", int group = 1;
matched.at("reshape1") if (sub_structure_ == "r_t_r") {
->stmt() group = matched.at("reshape1")
->op_info() ->stmt()
->GetAttr<std::vector<int>>("shape")[1]); ->op_info()
->GetAttr<std::vector<int>>("shape")[1];
}
if (sub_structure_ == "s_t_r") {
group = matched.at("stack_op")->inlinks.size();
}
op_desc.SetAttr("group", group);
return op_desc; return op_desc;
} }
......
...@@ -26,16 +26,25 @@ namespace fusion { ...@@ -26,16 +26,25 @@ namespace fusion {
class ShuffleChannelFuser : public FuseBase { class ShuffleChannelFuser : public FuseBase {
public: public:
explicit ShuffleChannelFuser(const std::string& reshape_type, explicit ShuffleChannelFuser(const std::string& reshape_type,
const std::string& transpose_type) const std::string& transpose_type,
: reshape_type_(reshape_type), transpose_type_(transpose_type) {} const std::string& sub_structure)
: reshape_type_(reshape_type),
transpose_type_(transpose_type),
sub_structure_(sub_structure) {}
void BuildPattern() override; void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private: private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
// reshape or reshape2
std::string reshape_type_; std::string reshape_type_;
// transpose or transpose2
std::string transpose_type_; std::string transpose_type_;
// r_t_r or t_r
// r_t_r: reshape + transpose + reshape
// s_t_r: stack + transpose + reshape
std::string sub_structure_;
}; };
} // namespace fusion } // namespace fusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册