提交 ab9c4b2a 编写于 作者: T tensor-tang

refine seqpool concat pass and remove unused nodes

test=develop
上级 ce909664
...@@ -76,6 +76,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, ...@@ -76,6 +76,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
std::vector<PDNode*> seqpool_ops_input_var(num_inputs); std::vector<PDNode*> seqpool_ops_input_var(num_inputs);
std::vector<PDNode*> seqpool_ops_output_var(num_inputs); std::vector<PDNode*> seqpool_ops_output_var(num_inputs);
std::vector<PDNode*> seqpool_ops_output_unused_var(num_inputs);
std::vector<PDNode*> seqpool_ops(num_inputs); std::vector<PDNode*> seqpool_ops(num_inputs);
for (int i = 0; i < num_inputs; ++i) { for (int i = 0; i < num_inputs; ++i) {
...@@ -88,6 +89,15 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, ...@@ -88,6 +89,15 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
}, },
name_scope + "/sequence_pool_out_" + std::to_string(i)); name_scope + "/sequence_pool_out_" + std::to_string(i));
seqpool_ops_output_unused_var[i] = pattern->NewNode(
[=](Node* x) {
return x && x->IsVar() && x->inputs.size() == 1 &&
x->outputs.size() == 0 &&
is_seqpool_op_with_pootype_of_nth_input_of_concat(x->inputs[0],
"SUM", i);
},
name_scope + "/sequence_pool_unused_out_" + std::to_string(i));
seqpool_ops[i] = pattern->NewNode( seqpool_ops[i] = pattern->NewNode(
[=](Node* x) { [=](Node* x) {
return x && x->IsOp() && return x && x->IsOp() &&
...@@ -97,16 +107,23 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern, ...@@ -97,16 +107,23 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
seqpool_ops_input_var[i] = pattern->NewNode( seqpool_ops_input_var[i] = pattern->NewNode(
[=](Node* x) { [=](Node* x) {
return x && x->IsVar() && x->outputs.size() >= 1 && bool basic = x && x->IsVar() && x->outputs.size() >= 1;
is_seqpool_op_with_pootype_of_nth_input_of_concat( bool next_is_fine = false;
x->outputs[0], "SUM", i); for (auto* o : x->outputs) {
if (is_seqpool_op_with_pootype_of_nth_input_of_concat(o, "SUM",
i)) {
next_is_fine = true;
break;
}
}
return basic && next_is_fine;
}, },
name_scope + "/sequence_pool_in_" + std::to_string(i)); name_scope + "/sequence_pool_in_" + std::to_string(i));
// Links // Links
seqpool_ops[i] seqpool_ops[i]
->LinksFrom({seqpool_ops_input_var[i]}) ->LinksFrom({seqpool_ops_input_var[i]})
.LinksTo({seqpool_ops_output_var[i]}); .LinksTo({seqpool_ops_output_var[i], seqpool_ops_output_unused_var[i]});
} }
concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var}); concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var});
return concat_out_var; return concat_out_var;
......
...@@ -35,11 +35,35 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -35,11 +35,35 @@ void SetOp(ProgramDesc* prog, const std::string& type,
op->SetInput("X", inputs); op->SetInput("X", inputs);
op->SetAttr("axis", 1); op->SetAttr("axis", 1);
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
} else {
op->SetInput("X", inputs);
op->SetOutput("Out", outputs);
} }
op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(), op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(OpRole::kForward)); static_cast<int>(OpRole::kForward));
} }
// Returns how many operator nodes in `graph` have type `op_type`
// (defaults to the fused op this pass is expected to produce).
int CountOpType(const ir::Graph* graph,
                const std::string& op_type = "fusion_seqpool_concat") {
  int num_found = 0;
  for (auto* node : graph->Nodes()) {
    if (!node->IsOp()) continue;  // skip variable nodes
    if (node->Op()->Type() == op_type) {
      ++num_found;
    }
  }
  return num_found;
}
// Runs `pass_type` on `graph`, recording the total node count before the
// pass in *before and after the pass in *after; returns the transformed
// graph (ownership passes through).
std::unique_ptr<ir::Graph> GetNumNodesOfBeforeAfter(
    std::unique_ptr<ir::Graph> graph, int* before, int* after,
    const std::string& pass_type = "seqpool_concat_fuse_pass") {
  auto pass = PassRegistry::Instance().Get(pass_type);
  *before = graph->Nodes().size();
  graph = pass->Apply(std::move(graph));
  *after = graph->Nodes().size();
  return graph;
}
/* /*
* Before fuse: * Before fuse:
* a b c * a b c
...@@ -51,15 +75,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, ...@@ -51,15 +75,16 @@ void SetOp(ProgramDesc* prog, const std::string& type,
* concat * concat
* | * |
* j * j
* Type of op1, op2 and op3 are sequence_pool, with "SUM" pooltype attr
*
* After fuse: * After fuse:
* a b c * a b c
* \ | / * \ | /
* fusion_seqpool_concat * fusion_seqpool_concat
* | * |
* j * j
* unused nodes: d, f, h
*/ */
ProgramDesc BuildProgramDesc() { TEST(SeqPoolConcatFusePass, basic) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : std::vector<std::string>( for (auto& v : std::vector<std::string>(
{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"})) { {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"})) {
...@@ -76,35 +101,94 @@ ProgramDesc BuildProgramDesc() { ...@@ -76,35 +101,94 @@ ProgramDesc BuildProgramDesc() {
SetOp(&prog, "concat", std::vector<std::string>({"e", "g", "i"}), SetOp(&prog, "concat", std::vector<std::string>({"e", "g", "i"}),
std::vector<std::string>({"j"})); std::vector<std::string>({"j"}));
return prog;
}
TEST(SeqPoolConcatFusePass, basic) {
auto prog = BuildProgramDesc();
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog)); std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove 10 Nodes: op1, op2, op3, d, e, f, g, h, i, concat_op
// Add 1 Node: fusion_seqpool_concat
EXPECT_EQ(after, before - 9);
EXPECT_EQ(CountOpType(graph.get()), 1);
}
auto pass = PassRegistry::Instance().Get("seqpool_concat_fuse_pass"); /*
* Before fuse:
int pre_nodes = graph->Nodes().size(); * a b
* | / \
graph = pass->Apply(std::move(graph)); * op1 op2 op3
* / \ / \ \
* c d e f g
* \ /
* concat
* |
* h
* Type of op1 and op2 are sequence_pool, with "SUM" pooltype attr
*
* After fuse:
* a b
* \ / \
* fusion_seqpool_concat op3
* | |
* h g
*/
// Fuses only op1/op2 (sequence_pool feeding concat); op3 and its output g
// must survive the pass untouched.
TEST(SeqPoolConcatFusePass, advanced) {
  ProgramDesc prog;
  const std::vector<std::string> var_names(
      {"a", "b", "c", "d", "e", "f", "g", "h"});
  for (const auto& name : var_names) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
  }

  SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
        std::vector<std::string>({"c", "d"}));
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
        std::vector<std::string>({"e", "f"}));
  SetOp(&prog, "op3", std::vector<std::string>({"b"}),
        std::vector<std::string>({"g"}));
  SetOp(&prog, "concat", std::vector<std::string>({"d", "f"}),
        std::vector<std::string>({"h"}));

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  int before, after;
  graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
  // Remove 7 Nodes: op1, op2, c, d, e, f, concat_op
  // Add 1 Node: fusion_seqpool_concat
  EXPECT_EQ(after, before - 6);
  EXPECT_EQ(CountOpType(graph.get()), 1);
}
// Builds a program with `num_inputs_of_concat` sequence_pool ops, each
// producing one output fed into a single concat plus one unused output,
// so the pass can be exercised with a variable number of concat inputs.
ProgramDesc BuildProgramDesc(int num_inputs_of_concat) {
  ProgramDesc prog;
  auto new_var = [&](const std::string& name) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
  };
  std::vector<std::string> concat_inputs;
  for (int i = 0; i < num_inputs_of_concat; ++i) {
    // BUG FIX: `"seqpool_op_" + i` was pointer arithmetic on the string
    // literal (UB once i exceeds the literal's length, and wrong prefixes
    // like "eqpool_op_" before that), not int-to-string concatenation.
    std::string prefix = "seqpool_op_" + std::to_string(i);
    new_var(prefix + "in");
    new_var(prefix + "out");
    new_var(prefix + "out_unused");
    SetOp(&prog, "sequence_pool", std::vector<std::string>({prefix + "in"}),
          std::vector<std::string>({prefix + "out", prefix + "out_unused"}));
    concat_inputs.push_back(prefix + "out");
  }
  SetOp(&prog, "concat", concat_inputs,
        std::vector<std::string>({"concat_out"}));
  return prog;
}
for (auto* node : graph->Nodes()) { // test more inputs of concat
if (node->IsOp() && node->Op()->Type() == "fusion_seqpool_concat") { TEST(SeqPoolConcatFusePass, more_inputs) {
++count; for (int num : {1, 2, 10}) {
} ProgramDesc prog = BuildProgramDesc(num);
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int before, after;
graph = GetNumNodesOfBeforeAfter(std::move(graph), &before, &after);
// Remove Nodes: n * (seqpool_op, out, out_unused), and concat_op
// Add Node: fusion_seqpool_concat op
EXPECT_EQ(after, before - num * 3);
EXPECT_EQ(CountOpType(graph.get()), 1);
} }
EXPECT_EQ(count, 1);
} }
} // namespace ir } // namespace ir
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册