diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc index 993b5a055280d869dea26e9a27b4b5c860717495..fb86ed4d22c9ffd220eb7e63509139e8f26d699d 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc @@ -138,6 +138,32 @@ void FusedMultiTransformerCacheKVLayoutTransPass::FillConstantReshapePass( AddStatis(found_subgraph_count); } +int FusedMultiTransformerCacheKVLayoutTransPass:: + CountFillConstantReshapePattern(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + GraphPatternDetector gpd; + patterns::FusedMultiTransformerFillConstantPattern pattern( + gpd.mutable_pattern(), name_scope_); + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle FillConstantReshapePass"; + GET_IR_NODE_FROM_SUBGRAPH( + fused_multi_transformer, fused_multi_transformer, pattern); + GET_IR_NODE_FROM_SUBGRAPH(fill_constant, fill_constant, pattern); + GET_IR_NODE_FROM_SUBGRAPH(fill_constant_out, fill_constant_out, pattern); + auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV"); + if (std::count(cachekv_names.begin(), + cachekv_names.end(), + fill_constant_out->Name()) == 0) + return; + found_subgraph_count++; + }; + gpd(graph, handler); + return found_subgraph_count; +} + void FusedMultiTransformerCacheKVLayoutTransPass::GatherReshapePass( ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( @@ -183,14 +209,53 @@ void FusedMultiTransformerCacheKVLayoutTransPass::GatherReshapePass( AddStatis(found_subgraph_count); } +int FusedMultiTransformerCacheKVLayoutTransPass::CountGatherReshapePattern( + ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + GraphPatternDetector gpd; + patterns::FusedMultiTransformerGatherPattern pattern(gpd.mutable_pattern(), + name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle GatherReshapePass"; + GET_IR_NODE_FROM_SUBGRAPH(gather, gather, pattern); + GET_IR_NODE_FROM_SUBGRAPH( + fused_multi_transformer, fused_multi_transformer, pattern); + GET_IR_NODE_FROM_SUBGRAPH(gather_in, gather_in, pattern); + GET_IR_NODE_FROM_SUBGRAPH(gather_out, gather_out, pattern); + auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV"); + if (std::count(cachekv_names.begin(), + cachekv_names.end(), + gather_out->Name()) == 0) + return; + found_subgraph_count++; + }; + gpd(graph, handler); + return found_subgraph_count; +} + void FusedMultiTransformerCacheKVLayoutTransPass::ApplyImpl( ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); + if (!graph->IsMainGraph()) { + VLOG(3) << "'fused_multi_transformer_cachekv_layout_pass' needs info in " + "all graphs, so it should be applied in the main graph."; + return; + } Init(name_scope_, graph); - - FillConstantReshapePass(graph); - GatherReshapePass(graph); + int pattern_cnt0 = 0, pattern_cnt1 = 0; + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + pattern_cnt0 += CountFillConstantReshapePattern(graph->GetSubGraph(i)); + pattern_cnt1 += CountGatherReshapePattern(graph->GetSubGraph(i)); + } + if (pattern_cnt0 != 0 && pattern_cnt1 != 0 && pattern_cnt0 == pattern_cnt1) { + FillConstantReshapePass(graph); + GatherReshapePass(graph); + } } } // namespace ir diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h index cb87317a76e6a040c97738f36f4707b1a1191b43..c691c055779f3dbd1a5bc99682c6c8e2674b96c2 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.h @@ -53,6 +53,8 @@ class FusedMultiTransformerCacheKVLayoutTransPass : public FusePassBase { */ void FillConstantReshapePass(ir::Graph* graph) const; + int CountFillConstantReshapePattern(ir::Graph* graph) const; + /* Origin subgraph: (gather_x: [d0,d1,d2,d3,d4]) @@ -70,6 +72,8 @@ class FusedMultiTransformerCacheKVLayoutTransPass : public FusePassBase { */ void GatherReshapePass(ir::Graph* graph) const; + int CountGatherReshapePattern(ir::Graph* graph) const; + const std::string name_scope_{ "fused_multi_transformer_cachekv_layout_trans_pass"}; }; diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc index ec5dba201163fffbbc6fa4182537c73a6e25083c..c846ef3fbea2f98b6ecbeac864fcbf04b29a3a4e 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass_test.cc @@ -73,19 +73,19 @@ TEST(FillConstantReshapePass, basic) { pass->Apply(graph.get()); auto fills = GetOpNodes(graph, "fill_constant"); auto fill0_in_names = fills[0]->Op()->Input("ShapeTensorList"); - std::vector expect_fill0_in_names{ - "shape0", "shape3", "shape1", "shape2", "shape4"}; + std::vector expect_fill0_out_names{ + "shape5", "shape6", "shape7", "shape8", "shape9"}; + std::vector expect_fill1_out_names{ + "shape0", "shape1", "shape2", "shape3", "shape4"}; PADDLE_ENFORCE_EQ(fill0_in_names, - expect_fill0_in_names, + expect_fill0_out_names, platform::errors::PreconditionNotMet( - "fill_constant name should be updated.")); + "fill_constant name should not be updated.")); auto fill1_in_names = fills[1]->Op()->Input("ShapeTensorList"); - std::vector expect_fill1_in_names{ - "shape5", "shape8", "shape6", "shape7", "shape9"}; PADDLE_ENFORCE_EQ(fill1_in_names, - expect_fill1_in_names, + expect_fill1_out_names, platform::errors::PreconditionNotMet( - "fill_constant name should be updated.")); + "fill_constant name should not be updated.")); } TEST(GatherReshapePass, basic) { @@ -109,6 +109,69 @@ TEST(GatherReshapePass, basic) { "fused_multi_transformer_cachekv_layout_trans_pass"); pass->Apply(graph.get()); auto gathers = GetOpNodes(graph, "gather"); + for (auto* gather : gathers) { + PADDLE_ENFORCE_EQ(gather->Op()->GetAttrIfExists("axis"), + 1, + platform::errors::PreconditionNotMet( + "gather's axis attr should not be updated by pass.")); + } +} + +TEST(FillConstantAndGatherReshapePass, basic) { + Layers layers; + auto* block = layers.Block(); + auto* shape0 = Data(block, "shape0"); + auto* shape1 = Data(block, "shape1"); + auto* shape2 = Data(block, "shape2"); + auto* shape3 = Data(block, "shape3"); + auto* shape4 = Data(block, "shape4"); + auto* shape5 = Data(block, "shape5"); + auto* shape6 = Data(block, "shape6"); + auto* shape7 = Data(block, "shape7"); + auto* shape8 = Data(block, "shape8"); + auto* shape9 = Data(block, "shape9"); + auto* fill0 = fill_constant(block, {shape0, shape1, shape2, shape3, shape4}); + fill0->SetShape({1, 2, 3, 4, 5}); + auto* fill1 = fill_constant(block, {shape5, shape6, shape7, shape8, shape9}); + fill1->SetShape({1, 2, 3, 4, 5}); + OpDesc* fused_multi_transformer = block->AppendOp(); + fused_multi_transformer->SetType("fused_multi_transformer"); + fused_multi_transformer->SetInput("CacheKV", {fill0->Name(), fill1->Name()}); + + auto* gather0_x = layers.data("gather0_x", {2, 1, 24, 512, 64}); + auto* gather0_index = layers.data("gather0_index", {1}); + auto* gather0_out = layers.gather(gather0_x, gather0_index, 1); + gather0_out->SetShape({2, 1, 24, 512, 64}); + auto* gather1_x = layers.data("gather1_x", {2, 1, 24, 512, 64}); + auto* gather1_index = layers.data("gather1_index", {1}); + auto* gather1_out = layers.gather(gather1_x, gather1_index, 1); + gather1_out->SetShape({2, 1, 24, 512, 64}); + OpDesc* fused_multi_transformer1 = block->AppendOp(); + fused_multi_transformer1->SetType("fused_multi_transformer"); + fused_multi_transformer1->SetInput( + "CacheKV", {gather0_out->Name(), gather1_out->Name()}); + + std::unique_ptr graph(new ir::Graph(layers.main_program())); + auto pass = PassRegistry::Instance().Get( + "fused_multi_transformer_cachekv_layout_trans_pass"); + pass->Apply(graph.get()); + + auto fills = GetOpNodes(graph, "fill_constant"); + auto fill0_in_names = fills[0]->Op()->Input("ShapeTensorList"); + std::vector expect_fill0_out_names{ + "shape0", "shape3", "shape1", "shape2", "shape4"}; + std::vector expect_fill1_out_names{ + "shape5", "shape8", "shape6", "shape7", "shape9"}; + PADDLE_ENFORCE_EQ(fill0_in_names, + expect_fill0_out_names, + platform::errors::PreconditionNotMet( + "fill_constant name should be updated.")); + auto fill1_in_names = fills[1]->Op()->Input("ShapeTensorList"); + PADDLE_ENFORCE_EQ(fill1_in_names, + expect_fill1_out_names, + platform::errors::PreconditionNotMet( + "fill_constant name should be updated.")); + auto gathers = GetOpNodes(graph, "gather"); for (auto* gather : gathers) { PADDLE_ENFORCE_EQ( gather->Op()->GetAttrIfExists("axis"),