From db96ae5891be39a9134963cec58792dbfe83d411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com> Date: Mon, 7 Aug 2023 16:30:37 +0800 Subject: [PATCH] [paddle-trt] x and y 's rank should be same in trt_skip_layernorm_pass (#56007) * commit * commit --------- Co-authored-by: zhoukangkang --- paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc index 65afdec4c05..81f96f2fc33 100644 --- a/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/trt_skip_layernorm_fuse_pass.cc @@ -152,6 +152,13 @@ void TrtSkipLayerNormFusePass::ApplyImpl(ir::Graph *graph) const { } VLOG(4) << "handle TrtSkipLayerNorm fuse"; + + // x and y 's rank must be same + if (subgraph.at(x)->Var()->GetShape().size() != + subgraph.at(y)->Var()->GetShape().size()) { + return; + } + GET_IR_NODE_FROM_SUBGRAPH(elementwise, elementwise, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out, fused_pattern); GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, fused_pattern); -- GitLab