未验证 提交 2ea1c134 编写于 作者: W Wangzheee 提交者: GitHub

fix trt and gpu pass: emb_elt_layn (#44842)

上级 d94b9686
......@@ -316,6 +316,13 @@ int EmbeddingEltwiseLayerNormFusePass::BuildFusion(
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
// todo: support any inputs with lookup_table_v2
if (ids.size() < 3) {
VLOG(3) << "embedding_eltwise_layernorm_fuse_pass only support >=3 "
"inputs with lookup_table_v2";
return fusion_count;
}
OpDesc new_op_desc;
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
......
......@@ -326,6 +326,12 @@ int TrtEmbeddingEltwiseLayerNormFusePass::BuildFusion(
embs.push_back(inner_pattern_ins[js[iter]].second->Name());
}
// todo: support any inputs with lookup_table_v2
if (ids.size() < 3) {
VLOG(3) << "trt_embedding_eltwise_layernorm_fuse_pass only support >=3 "
"inputs with lookup_table_v2";
return fusion_count;
}
OpDesc new_op_desc(end_patter_layernorms[0]->Op()->Block());
new_op_desc.SetType("fused_embedding_eltwise_layernorm");
new_op_desc.SetInput("Ids", ids);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册