提交 1af0e6c3 编写于 作者: Z zhupengyang 提交者: GitHub

enhance shuffle-channel fuse pass (#3208)

上级 fcc6b2da
......@@ -23,11 +23,15 @@ namespace lite {
namespace mir {
void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ShuffleChannelFuser fuser("reshape", "transpose");
fuser(graph.get());
fusion::ShuffleChannelFuser fuser2("reshape2", "transpose2");
fuser2(graph.get());
for (std::string reshape_type : {"reshape", "reshape2"}) {
for (std::string transpose_type : {"transpose", "transpose2"}) {
for (std::string sub_structure : {"r_t_r", "s_t_r"}) {
fusion::ShuffleChannelFuser fuser(
reshape_type, transpose_type, sub_structure);
fuser(graph.get());
}
}
}
}
} // namespace mir
......
......@@ -22,56 +22,107 @@ namespace mir {
namespace fusion {
void ShuffleChannelFuser::BuildPattern() {
// create nodes.
auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X");
auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out");
auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
PMNode* xshape1 = nullptr;
PMNode* xshape2 = nullptr;
PMNode* xshape3 = nullptr;
if (reshape_type_ == "reshape2") {
xshape1 = VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
xshape3 = VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape");
}
if (transpose_type_ == "transpose2") {
xshape2 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
if (sub_structure_ == "r_t_r") {
// create nodes.
auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X");
auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out");
auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
PMNode* xshape1 = nullptr;
PMNode* xshape2 = nullptr;
PMNode* xshape3 = nullptr;
if (reshape_type_ == "reshape2") {
xshape1 =
VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
xshape3 =
VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape");
}
if (transpose_type_ == "transpose2") {
xshape2 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
}
auto* reshape1 = OpNode("reshape1", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] > 0;
});
auto* transpose =
OpNode("transpose_op", transpose_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"axis", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1;
});
auto* reshape2 = OpNode("reshape2", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 4;
});
// create topology.
*x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out;
if (xshape1) *reshape1 >> *xshape1;
if (xshape2) *transpose >> *xshape2;
if (xshape3) *reshape2 >> *xshape3;
// Some op specialities.
y1->AsIntermediate();
y2->AsIntermediate();
if (xshape1) xshape1->AsIntermediate();
if (xshape2) xshape2->AsIntermediate();
if (xshape3) xshape3->AsIntermediate();
reshape1->AsIntermediate();
transpose->AsIntermediate();
reshape2->AsIntermediate();
}
auto* reshape1 = OpNode("reshape1", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] > 0;
});
auto* transpose =
OpNode("transpose_op", transpose_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"axis", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1;
});
auto* reshape2 = OpNode("reshape2", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 4;
});
// create topology.
*x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out;
if (xshape1) *reshape1 >> *xshape1;
if (xshape2) *transpose >> *xshape2;
if (xshape3) *reshape2 >> *xshape3;
// Some op specialities.
y1->AsIntermediate();
y2->AsIntermediate();
if (xshape1) xshape1->AsIntermediate();
if (xshape2) xshape2->AsIntermediate();
if (xshape3) xshape3->AsIntermediate();
reshape1->AsIntermediate();
transpose->AsIntermediate();
reshape2->AsIntermediate();
if (sub_structure_ == "s_t_r") {
// create nodes.
auto* x1 = VarNode("x1")
->assert_is_op_input(transpose_type_, "X")
->assert_is_op_output("stack", "Y");
auto* y1 = VarNode("y1")->assert_is_op_output(transpose_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
PMNode* xshape1 = nullptr;
PMNode* xshape2 = nullptr;
if (transpose_type_ == "transpose2") {
xshape1 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
}
if (reshape_type_ == "reshape2") {
xshape2 =
VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
}
auto* stack = OpNode("stack_op", "stack")
->assert_op_attr_satisfied<int>(
"axis", [](const int& attr) { return attr == 1; });
auto* transpose =
OpNode("transpose_op", transpose_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"axis", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1;
});
auto* reshape = OpNode("reshape_op", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 4;
});
// create topology.
*stack >> *x1 >> *transpose >> *y1 >> *reshape >> *out;
if (xshape1) *transpose >> *xshape1;
if (xshape2) *reshape >> *xshape2;
// Some op specialities.
y1->AsIntermediate();
if (xshape1) xshape1->AsIntermediate();
if (xshape2) xshape2->AsIntermediate();
transpose->AsIntermediate();
reshape->AsIntermediate();
}
}
void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph,
......@@ -95,11 +146,17 @@ cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) {
op_desc.SetType("shuffle_channel");
op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("group",
matched.at("reshape1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("shape")[1]);
int group = 1;
if (sub_structure_ == "r_t_r") {
group = matched.at("reshape1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("shape")[1];
}
if (sub_structure_ == "s_t_r") {
group = matched.at("stack_op")->inlinks.size();
}
op_desc.SetAttr("group", group);
return op_desc;
}
......
......@@ -26,16 +26,25 @@ namespace fusion {
class ShuffleChannelFuser : public FuseBase {
public:
explicit ShuffleChannelFuser(const std::string& reshape_type,
const std::string& transpose_type)
: reshape_type_(reshape_type), transpose_type_(transpose_type) {}
const std::string& transpose_type,
const std::string& sub_structure)
: reshape_type_(reshape_type),
transpose_type_(transpose_type),
sub_structure_(sub_structure) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
// reshape or reshape2
std::string reshape_type_;
// transpose or transpose2
std::string transpose_type_;
// r_t_r or t_r
// r_t_r: reshape + transpose + reshape
// s_t_r: stack + transpose + reshape
std::string sub_structure_;
};
} // namespace fusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册