未验证 提交 db96ae58 编写于 作者: 周周周 提交者: GitHub

[paddle-trt] x and y 's rank should be same in trt_skip_layernorm_pass (#56007)

* commit

* commit

---------
Co-authored-by: Nzhoukangkang <zhoukangkang@baidu.com>
上级 5ada98b8
......@@ -152,6 +152,13 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const {
}
VLOG(4) << "handle TrtSkipLayerNorm fuse";
// x and y 's rank must be same
if (subgraph.at(x)->Var()->GetShape().size() !=
subgraph.at(y)->Var()->GetShape().size()) {
return;
}
GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册