未验证 提交 19d6a988 编写于 作者: J JiangHao 提交者: GitHub

fix cachek_kv_layout pass (#54994)

上级 2ee5b296
......@@ -149,10 +149,9 @@ int FusedMultiTransformerCacheKVLayoutTransPass::
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FillConstantReshapePass";
GET_IR_NODE_FROM_SUBGRAPH(
fused_multi_transformer, fused_multi_transformer, pattern);
GET_IR_NODE_FROM_SUBGRAPH(fill_constant, fill_constant, pattern);
GET_IR_NODE_FROM_SUBGRAPH(fill_constant_out, fill_constant_out, pattern);
GET_IR_NODE(fused_multi_transformer);
GET_IR_NODE(fill_constant);
GET_IR_NODE(fill_constant_out);
auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
if (std::count(cachekv_names.begin(),
cachekv_names.end(),
......@@ -221,11 +220,10 @@ int FusedMultiTransformerCacheKVLayoutTransPass::CountGatherReshapePattern(
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle GatherReshapePass";
GET_IR_NODE_FROM_SUBGRAPH(gather, gather, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
fused_multi_transformer, fused_multi_transformer, pattern);
GET_IR_NODE_FROM_SUBGRAPH(gather_in, gather_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(gather_out, gather_out, pattern);
GET_IR_NODE(gather);
GET_IR_NODE(fused_multi_transformer);
GET_IR_NODE(gather_in);
GET_IR_NODE(gather_out);
auto cachekv_names = fused_multi_transformer->Op()->Input("CacheKV");
if (std::count(cachekv_names.begin(),
cachekv_names.end(),
......@@ -253,8 +251,10 @@ void FusedMultiTransformerCacheKVLayoutTransPass::ApplyImpl(
pattern_cnt1 += CountGatherReshapePattern(graph->GetSubGraph(i));
}
if (pattern_cnt0 != 0 && pattern_cnt1 != 0 && pattern_cnt0 == pattern_cnt1) {
FillConstantReshapePass(graph);
GatherReshapePass(graph);
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
FillConstantReshapePass(graph->GetSubGraph(i));
GatherReshapePass(graph->GetSubGraph(i));
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册