Unverified commit 855650ec, authored by JiangHao, committed by GitHub

only use fused_multi_transformer_cachekv_layout_trans_pass in beam search case (#54665)

Parent 20bf9592
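
Background: in beam search decoding, each CacheKV input of fused_multi_transformer is created by a fill_constant op and then reordered every step by a gather op over the beam dimension, so the two patterns occur in equal, non-zero numbers; other decoding modes lack the gather on CacheKV. The pass is therefore now gated on matching pattern counts. Below is a minimal standalone C++ sketch of that gating decision (function and parameter names are hypothetical, not Paddle APIs):

#include <iostream>

// Stand-ins for the results of CountFillConstantReshapePattern and
// CountGatherReshapePattern summed over all subgraphs.
bool ShouldApplyCacheKVLayoutTrans(int fill_constant_cnt, int gather_cnt) {
  // Beam search: every CacheKV is both filled and gathered, so the counts
  // are non-zero and equal; otherwise the layout transform is skipped.
  return fill_constant_cnt != 0 && gather_cnt != 0 &&
         fill_constant_cnt == gather_cnt;
}

int main() {
  std::cout << ShouldApplyCacheKVLayoutTrans(2, 2) << "\n";  // 1: beam search
  std::cout << ShouldApplyCacheKVLayoutTrans(2, 0) << "\n";  // 0: no gather
  return 0;
}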
......@@ -138,6 +138,32 @@ void FusedMultiTransformerCacheKVLayoutTransPass::FillConstantReshapePass(
AddStatis(found_subgraph_count);
}
int FusedMultiTransformerCacheKVLayoutTransPass::
CountFillConstantReshapePattern(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
GraphPatternDetector gpd;
patterns::FusedMultiTransformerFillConstantPattern pattern(
gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FillConstantReshapePass";
GET_IR_NODE_FROM_SUBGRAPH(
fused_multi_transformer, fused_multi_transformer, pattern);
GET_IR_NODE_FROM_SUBGRAPH(fill_constant, fill_constant, pattern);
GET_IR_NODE_FROM_SUBGRAPH(fill_constant_out, fill_constant_out, pattern);
auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
if (std::count(cachekv_names.begin(),
cachekv_names.end(),
fill_constant_out->Name()) == 0)
return;
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void FusedMultiTransformerCacheKVLayoutTransPass::GatherReshapePass(
ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
......@@ -183,14 +209,53 @@ void FusedMultiTransformerCacheKVLayoutTransPass::GatherReshapePass(
AddStatis(found_subgraph_count);
}
int FusedMultiTransformerCacheKVLayoutTransPass::CountGatherReshapePattern(
ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
GraphPatternDetector gpd;
patterns::FusedMultiTransformerGatherPattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle GatherReshapePass";
GET_IR_NODE_FROM_SUBGRAPH(gather, gather, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fused_multi_transformer, fused_multi_transformer, pattern);
GET_IR_NODE_FROM_SUBGRAPH(gather_in, gather_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(gather_out, gather_out, pattern);
auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
if (std::count(cachekv_names.begin(),
cachekv_names.end(),
gather_out->Name()) == 0)
return;
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void FusedMultiTransformerCacheKVLayoutTransPass::ApplyImpl(
ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
if (!graph->IsMainGraph()) {
VLOG(3) << "'fused_multi_transformer_cachekv_layout_pass' needs info in "
"all graphs, so it should be applied in the main graph.";
return;
}
Init(name_scope_, graph);
-  FillConstantReshapePass(graph);
-  GatherReshapePass(graph);
+  int pattern_cnt0 = 0, pattern_cnt1 = 0;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    pattern_cnt0 += CountFillConstantReshapePattern(graph->GetSubGraph(i));
+    pattern_cnt1 += CountGatherReshapePattern(graph->GetSubGraph(i));
+  }
+  if (pattern_cnt0 != 0 && pattern_cnt1 != 0 && pattern_cnt0 == pattern_cnt1) {
+    FillConstantReshapePass(graph);
+    GatherReshapePass(graph);
+  }
}
} // namespace ir
......
......@@ -53,6 +53,8 @@ class FusedMultiTransformerCacheKVLayoutTransPass : public FusePassBase {
*/
void FillConstantReshapePass(ir::Graph* graph) const;
int CountFillConstantReshapePattern(ir::Graph* graph) const;
/*
Origin subgraph:
(gather_x: [d0,d1,d2,d3,d4])
......@@ -70,6 +72,8 @@ class FusedMultiTransformerCacheKVLayoutTransPass : public FusePassBase {
*/
void GatherReshapePass(ir::Graph* graph) const;
int CountGatherReshapePattern(ir::Graph* graph) const;
const std::string name_scope_{
"fused_multi_transformer_cachekv_layout_trans_pass"};
};
......
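
For context before the tests: the layout transform permutes a CacheKV fill_constant's ShapeTensorList by moving dimension 3 to position 1, i.e. [d0, d1, d2, d3, d4] -> [d0, d3, d1, d2, d4], which is what the expected name orders in the tests below encode. A minimal sketch with a hypothetical helper (the real pass rewires the op's graph inputs rather than calling a function):

#include <cassert>
#include <string>
#include <vector>

// Hypothetical helper; mirrors the input reordering that
// FillConstantReshapePass performs on a CacheKV producer's ShapeTensorList.
std::vector<std::string> TransCacheKVShapeOrder(
    const std::vector<std::string>& dims) {
  // [d0, d1, d2, d3, d4] -> [d0, d3, d1, d2, d4]
  return {dims[0], dims[3], dims[1], dims[2], dims[4]};
}

int main() {
  auto out = TransCacheKVShapeOrder(
      {"shape0", "shape1", "shape2", "shape3", "shape4"});
  // Matches expect_fill0_out_names in the FillConstantAndGatherReshapePass
  // test below.
  assert((out == std::vector<std::string>{
             "shape0", "shape3", "shape1", "shape2", "shape4"}));
  return 0;
}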
......@@ -73,19 +73,19 @@ TEST(FillConstantReshapePass, basic) {
pass->Apply(graph.get());
auto fills = GetOpNodes(graph, "fill_constant");
auto fill0_in_names = fills[0]->Op()->Input("ShapeTensorList");
-  std::vector<std::string> expect_fill0_in_names{
-      "shape0", "shape3", "shape1", "shape2", "shape4"};
+  std::vector<std::string> expect_fill0_out_names{
+      "shape5", "shape6", "shape7", "shape8", "shape9"};
+  std::vector<std::string> expect_fill1_out_names{
+      "shape0", "shape1", "shape2", "shape3", "shape4"};
  PADDLE_ENFORCE_EQ(fill0_in_names,
-                    expect_fill0_in_names,
+                    expect_fill0_out_names,
                    platform::errors::PreconditionNotMet(
-                        "fill_constant name should be updated."));
+                        "fill_constant name should not be updated."));
  auto fill1_in_names = fills[1]->Op()->Input("ShapeTensorList");
-  std::vector<std::string> expect_fill1_in_names{
-      "shape5", "shape8", "shape6", "shape7", "shape9"};
  PADDLE_ENFORCE_EQ(fill1_in_names,
-                    expect_fill1_in_names,
+                    expect_fill1_out_names,
                    platform::errors::PreconditionNotMet(
-                        "fill_constant name should be updated."));
+                        "fill_constant name should not be updated."));
}
TEST(GatherReshapePass, basic) {
......@@ -109,6 +109,69 @@ TEST(GatherReshapePass, basic) {
"fused_multi_transformer_cachekv_layout_trans_pass");
pass->Apply(graph.get());
auto gathers = GetOpNodes(graph, "gather");
for (auto* gather : gathers) {
PADDLE_ENFORCE_EQ(gather->Op()->GetAttrIfExists<int>("axis"),
1,
platform::errors::PreconditionNotMet(
"gather's axis attr should not be updated by pass."));
}
}
TEST(FillConstantAndGatherReshapePass, basic) {
Layers layers;
auto* block = layers.Block();
auto* shape0 = Data(block, "shape0");
auto* shape1 = Data(block, "shape1");
auto* shape2 = Data(block, "shape2");
auto* shape3 = Data(block, "shape3");
auto* shape4 = Data(block, "shape4");
auto* shape5 = Data(block, "shape5");
auto* shape6 = Data(block, "shape6");
auto* shape7 = Data(block, "shape7");
auto* shape8 = Data(block, "shape8");
auto* shape9 = Data(block, "shape9");
auto* fill0 = fill_constant(block, {shape0, shape1, shape2, shape3, shape4});
fill0->SetShape({1, 2, 3, 4, 5});
auto* fill1 = fill_constant(block, {shape5, shape6, shape7, shape8, shape9});
fill1->SetShape({1, 2, 3, 4, 5});
OpDesc* fused_multi_transformer = block->AppendOp();
fused_multi_transformer->SetType("fused_multi_transformer");
fused_multi_transformer->SetInput("CacheKV", {fill0->Name(), fill1->Name()});
auto* gather0_x = layers.data("gather0_x", {2, 1, 24, 512, 64});
auto* gather0_index = layers.data("gather0_index", {1});
auto* gather0_out = layers.gather(gather0_x, gather0_index, 1);
gather0_out->SetShape({2, 1, 24, 512, 64});
auto* gather1_x = layers.data("gather1_x", {2, 1, 24, 512, 64});
auto* gather1_index = layers.data("gather1_index", {1});
auto* gather1_out = layers.gather(gather1_x, gather1_index, 1);
gather1_out->SetShape({2, 1, 24, 512, 64});
OpDesc* fused_multi_transformer1 = block->AppendOp();
fused_multi_transformer1->SetType("fused_multi_transformer");
fused_multi_transformer1->SetInput(
"CacheKV", {gather0_out->Name(), gather1_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get(
"fused_multi_transformer_cachekv_layout_trans_pass");
pass->Apply(graph.get());
auto fills = GetOpNodes(graph, "fill_constant");
auto fill0_in_names = fills[0]->Op()->Input("ShapeTensorList");
std::vector<std::string> expect_fill0_out_names{
"shape0", "shape3", "shape1", "shape2", "shape4"};
std::vector<std::string> expect_fill1_out_names{
"shape5", "shape8", "shape6", "shape7", "shape9"};
PADDLE_ENFORCE_EQ(fill0_in_names,
expect_fill0_out_names,
platform::errors::PreconditionNotMet(
"fill_constant name should be updated."));
auto fill1_in_names = fills[1]->Op()->Input("ShapeTensorList");
PADDLE_ENFORCE_EQ(fill1_in_names,
expect_fill1_out_names,
platform::errors::PreconditionNotMet(
"fill_constant name should be updated."));
auto gathers = GetOpNodes(graph, "gather");
for (auto* gather : gathers) {
PADDLE_ENFORCE_EQ(
gather->Op()->GetAttrIfExists<int>("axis"),
......