From 19d6a988e0f2873680863e4cb90119e5f20e30f0 Mon Sep 17 00:00:00 2001 From: JiangHao Date: Fri, 30 Jun 2023 09:53:07 +0800 Subject: [PATCH] fix cachek_kv_layout pass (#54994) --- ...i_transformer_cachekv_layout_trans_pass.cc | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc index fb86ed4d22c..3136cab5475 100644 --- a/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fused_multi_transformer_cachekv_layout_trans_pass.cc @@ -149,10 +149,9 @@ int FusedMultiTransformerCacheKVLayoutTransPass:: auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle FillConstantReshapePass"; - GET_IR_NODE_FROM_SUBGRAPH( - fused_multi_transformer, fused_multi_transformer, pattern); - GET_IR_NODE_FROM_SUBGRAPH(fill_constant, fill_constant, pattern); - GET_IR_NODE_FROM_SUBGRAPH(fill_constant_out, fill_constant_out, pattern); + GET_IR_NODE(fused_multi_transformer); + GET_IR_NODE(fill_constant); + GET_IR_NODE(fill_constant_out); auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV"); if (std::count(cachekv_names.begin(), cachekv_names.end(), @@ -221,11 +220,10 @@ int FusedMultiTransformerCacheKVLayoutTransPass::CountGatherReshapePattern( auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { VLOG(4) << "handle GatherReshapePass"; - GET_IR_NODE_FROM_SUBGRAPH(gather, gather, pattern); - GET_IR_NODE_FROM_SUBGRAPH( - fused_multi_transformer, fused_multi_transformer, pattern); - GET_IR_NODE_FROM_SUBGRAPH(gather_in, gather_in, pattern); - GET_IR_NODE_FROM_SUBGRAPH(gather_out, gather_out, pattern); + GET_IR_NODE(gather); + GET_IR_NODE(fused_multi_transformer); + GET_IR_NODE(gather_in); + GET_IR_NODE(gather_out); auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV"); if (std::count(cachekv_names.begin(), cachekv_names.end(), @@ -253,8 +251,10 @@ void FusedMultiTransformerCacheKVLayoutTransPass::ApplyImpl( pattern_cnt1 += CountGatherReshapePattern(graph->GetSubGraph(i)); } if (pattern_cnt0 != 0 && pattern_cnt1 != 0 && pattern_cnt0 == pattern_cnt1) { - FillConstantReshapePass(graph); - GatherReshapePass(graph); + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + FillConstantReshapePass(graph->GetSubGraph(i)); + GatherReshapePass(graph->GetSubGraph(i)); + } } } -- GitLab