From 22511f3022896831f451469f9815951524409eec Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Mon, 14 Feb 2022 08:47:32 +0000 Subject: [PATCH] support preln_ernie --- paddle/fluid/framework/ir/pass.h | 2 ++ .../ir/preln_embedding_eltwise_layernorm_fuse_pass.cc | 2 +- paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc | 6 +++--- .../inference/analysis/ir_passes/tensorrt_subgraph_pass.cc | 6 ++++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/framework/ir/pass.h b/paddle/fluid/framework/ir/pass.h index 016d0fd4a66..acfe8d53cea 100644 --- a/paddle/fluid/framework/ir/pass.h +++ b/paddle/fluid/framework/ir/pass.h @@ -47,6 +47,8 @@ constexpr char kPassRecorder[] = "pass_recorder"; constexpr char kEmbEltwiseLayernormPass[] = "embedding_eltwise_layernorm_fuse_pass_flag"; constexpr char kMultiheadMatmulPass[] = "multihead_matmul_fuse_pass_flag"; +constexpr char kPrelnEmbEltwiseLayernormPass[] = + "preln_embedding_eltwise_layernorm_fuse_pass_flag"; class Pass { public: diff --git a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc index ba7f2a4ce1d..ca42a613411 100644 --- a/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_embedding_eltwise_layernorm_fuse_pass.cc @@ -364,7 +364,7 @@ int PrelnEmbeddingEltwiseLayerNormFusePass::BuildFusion( IR_NODE_LINK_TO(end_pattern_biases[k], preln_embedding_eltwise_layernorm); IR_NODE_LINK_TO(end_pattern_scales[k], preln_embedding_eltwise_layernorm); IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, end_pattern_out[k]); - IR_NODE_LINK_TO(embedding_eltwise_layernorm, inner_pattern_out[k]); + IR_NODE_LINK_TO(preln_embedding_eltwise_layernorm, inner_pattern_out[k]); // Remove unneeded nodes. std::unordered_set marked_nodes; diff --git a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc index 88983fd1993..1b7b82cbca9 100644 --- a/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc +++ b/paddle/fluid/framework/ir/preln_skip_layernorm_fuse_pass.cc @@ -53,7 +53,7 @@ struct PrelnSkipLayerNorm : public PatternBase { PATTERN_DECL_NODE(layer_norm_variance); }; -void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) { +void PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) { // Create nodes for elementwise add op. x->assert_is_op_input("elementwise_add", "X"); y->assert_is_op_input("elementwise_add", "Y"); @@ -61,8 +61,8 @@ void *PrelnSkipLayerNorm::operator()(PDNode *x, PDNode *y) { pattern->NewNode(elementwise_repr())->assert_is_op("elementwise_add"); auto *elementwise_out_var = pattern->NewNode(elementwise_out_repr()) ->assert_is_op_output("elementwise_add") - ->assert_is_op_input("layer_norm", "X"); - ->assert_is_op_input("elementwise_add", "Y"); + ->assert_is_op_input("layer_norm", "X") + ->assert_is_op_input("elementwise_add", "Y"); // Add links for elementwise_add op. elementwise->LinksFrom({x, y}).LinksTo({elementwise_out_var}); diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 55bbc554508..00240e8790a 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -377,8 +377,10 @@ void TensorRtSubgraphPass::CreateTensorRTOp( trt_engine->SetDLACore(Get("trt_dla_core")); trt_engine->SetWithErnie( - graph->Has(framework::ir::kEmbEltwiseLayernormPass) && - graph->Has(framework::ir::kMultiheadMatmulPass)); + (graph->Has(framework::ir::kEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass)) || + (graph->Has(framework::ir::kPrelnEmbEltwiseLayernormPass) && + graph->Has(framework::ir::kMultiheadMatmulPass))); if (use_static_engine) { trt_engine_serialized_data = GetTrtEngineSerializedData( -- GitLab