提交 1af0e6c3 编写于 作者: Z zhupengyang 提交者: GitHub

enhance shuffle-channel fuse pass (#3208)

上级 fcc6b2da
...@@ -23,11 +23,15 @@ namespace lite { ...@@ -23,11 +23,15 @@ namespace lite {
namespace mir { namespace mir {
void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ShuffleChannelFuser fuser("reshape", "transpose"); for (std::string reshape_type : {"reshape", "reshape2"}) {
for (std::string transpose_type : {"transpose", "transpose2"}) {
for (std::string sub_structure : {"r_t_r", "s_t_r"}) {
fusion::ShuffleChannelFuser fuser(
reshape_type, transpose_type, sub_structure);
fuser(graph.get()); fuser(graph.get());
}
fusion::ShuffleChannelFuser fuser2("reshape2", "transpose2"); }
fuser2(graph.get()); }
} }
} // namespace mir } // namespace mir
......
...@@ -22,6 +22,7 @@ namespace mir { ...@@ -22,6 +22,7 @@ namespace mir {
namespace fusion { namespace fusion {
void ShuffleChannelFuser::BuildPattern() { void ShuffleChannelFuser::BuildPattern() {
if (sub_structure_ == "r_t_r") {
// create nodes. // create nodes.
auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X"); auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X");
auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out"); auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out");
...@@ -32,8 +33,10 @@ void ShuffleChannelFuser::BuildPattern() { ...@@ -32,8 +33,10 @@ void ShuffleChannelFuser::BuildPattern() {
PMNode* xshape2 = nullptr; PMNode* xshape2 = nullptr;
PMNode* xshape3 = nullptr; PMNode* xshape3 = nullptr;
if (reshape_type_ == "reshape2") { if (reshape_type_ == "reshape2") {
xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape"); xshape1 =
xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape"); VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
xshape3 =
VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape");
} }
if (transpose_type_ == "transpose2") { if (transpose_type_ == "transpose2") {
xshape2 = xshape2 =
...@@ -72,6 +75,54 @@ void ShuffleChannelFuser::BuildPattern() { ...@@ -72,6 +75,54 @@ void ShuffleChannelFuser::BuildPattern() {
reshape1->AsIntermediate(); reshape1->AsIntermediate();
transpose->AsIntermediate(); transpose->AsIntermediate();
reshape2->AsIntermediate(); reshape2->AsIntermediate();
}
if (sub_structure_ == "s_t_r") {
// create nodes.
auto* x1 = VarNode("x1")
->assert_is_op_input(transpose_type_, "X")
->assert_is_op_output("stack", "Y");
auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
PMNode* xshape1 = nullptr;
PMNode* xshape2 = nullptr;
if (transpose_type_ == "transpose2") {
xshape1 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
}
if (reshape_type_ == "reshape2") {
xshape2 =
VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
}
auto* stack = OpNode("stack_op", "stack")
->assert_op_attr_satisfied<int>(
"axis", [](const int& attr) { return attr == 1; });
auto* transpose =
OpNode("transpose_op", transpose_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"axis", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1;
});
auto* reshape = OpNode("reshape_op", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 4;
});
// create topology.
*stack >> *x1 >> *transpose >> *y1 >> *reshape >> *out;
if (xshape1) *transpose >> *xshape1;
if (xshape2) *reshape >> *xshape2;
// Some op specialities.
y1->AsIntermediate();
if (xshape1) xshape1->AsIntermediate();
if (xshape2) xshape2->AsIntermediate();
transpose->AsIntermediate();
reshape->AsIntermediate();
}
} }
void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph, void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph,
...@@ -95,11 +146,17 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) { ...@@ -95,11 +146,17 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetType("shuffle_channel"); op_desc.SetType("shuffle_channel");
op_desc.SetInput("X", {matched.at("x1")->arg()->name}); op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name}); op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("group", int group = 1;
matched.at("reshape1") if (sub_structure_ == "r_t_r") {
group = matched.at("reshape1")
->stmt() ->stmt()
->op_info() ->op_info()
->GetAttr<std::vector<int>>("shape")[1]); ->GetAttr<std::vector<int>>("shape")[1];
}
if (sub_structure_ == "s_t_r") {
group = matched.at("stack_op")->inlinks.size();
}
op_desc.SetAttr("group", group);
return op_desc; return op_desc;
} }
......
...@@ -26,16 +26,25 @@ namespace fusion { ...@@ -26,16 +26,25 @@ namespace fusion {
class ShuffleChannelFuser : public FuseBase { class ShuffleChannelFuser : public FuseBase {
public: public:
explicit ShuffleChannelFuser(const std::string& reshape_type, explicit ShuffleChannelFuser(const std::string& reshape_type,
const std::string& transpose_type) const std::string& transpose_type,
: reshape_type_(reshape_type), transpose_type_(transpose_type) {} const std::string& sub_structure)
: reshape_type_(reshape_type),
transpose_type_(transpose_type),
sub_structure_(sub_structure) {}
void BuildPattern() override; void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private: private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
// reshape or reshape2
std::string reshape_type_; std::string reshape_type_;
// transpose or transpose2
std::string transpose_type_; std::string transpose_type_;
// r_t_r or t_r
// r_t_r: reshape + transpose + reshape
// s_t_r: stack + transpose + reshape
std::string sub_structure_;
}; };
} // namespace fusion } // namespace fusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册