diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 016d0fd4a663ecfcc8d2b23ddb2a3af7b610b6cd..acfe8d53cea13cb5ac9797ea7d43311d01b9041b 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -47,6 +47,8 @@ constexpr char kPassRecorder[] = "pass_recorder"; constexpr char kEmbEltwiseLayernormPass[] = "embedding_eltwise_layernorm_fuse_pass_flag"; constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag"; +constexpr char kPrelnEmbEltwiseLayernormPass[] = + "preln_embedding_eltwise_layernorm_fuse_pass_flag"; class Pass { public: diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc index ba7f2a4ce1dc8becd5790ad64bea60ec4e85c33c..ca42a613411ba6078b00522d2c178856993fa462 100644 --- a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc @@ -364,7 +364,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( IR_NODE_LINK_TO(end_pattern_biases[k], preln_embedding_eltwise_layernorm); IR_NODE_LINK_TO(end_pattern_scales[k], preln_embedding_eltwise_layernorm); IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_out[k]); - IR_NODE_LINK_TO(embedding_eltwise_layernorm, inner_pattern_out[k]); + IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, inner_pattern_out[k]); // Remove unneeded nodes. 
std::unordered_set<const Node *> marked_nodes; diff --git a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc index 88983fd199358ecc76097d0a2d4e03f34c829bf0..1b7b82cbca9e86587467fa0888eca6c6fdc2e162 100644 --- a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc @@ -53,7 +53,7 @@ struct PrelnSkipLayerNorm : public PatternBase { PATTERN_DECL_NODE(layer_norm_variance); }; -void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) { +void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) { // Create nodes for elementwise add op. x->assert_is_op_input("elementwise_add", "X"); y->assert_is_op_input("elementwise_add", "Y"); @@ -61,8 +61,8 @@ void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) { pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr()) ->assert_is_op_output("elementwise_add") - ->assert_is_op_input("layer_norm", "X"); - ->assert_is_op_input("elementwise_add", "Y"); + ->assert_is_op_input("layer_norm", "X") + ->assert_is_op_input("elementwise_add", "Y"); // Add links for elementwise_add op. 
elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var}); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 55bbc55450876c47dc0affb27323dbf397cc5c6c..00240e8790a79fa00c8ac1740ece005ff022971f 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -377,8 +377,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp( trt_engine->SetDLACore(Get<int>("trt_dla_core")); trt_engine->SetWithErnie( - (graph->Has(framework::ir::kEmbEltwiseLayernormPass) && - graph->Has(framework::ir::kMultiheadMatmulPass)); + (graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass)) || + (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass))); if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData(