From 439b2b946478e41af5dd72399c867924aabbb03b Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 15 Dec 2022 10:31:54 +0800 Subject: [PATCH] fix embedding multihead (#49085) --- .../tensorrt/convert/emb_eltwise_layernorm.cc | 2 +- .../tensorrt/convert/multihead_matmul_op.cc | 15 ++++++++++----- .../plugin/many_emb_layernorm_varseqlen_plugin.cu | 6 +++--- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc index a70839ee9d4..fcf20848b07 100644 --- a/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc +++ b/paddle/fluid/inference/tensorrt/convert/emb_eltwise_layernorm.cc @@ -181,7 +181,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { layer = plugin_layer; auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, - "ManyEmbLayerNormPluginDynamic_V1", + "ManyEmbLayerNormVarlenPluginDynamicV1", {output_name, std::string("qkv_plugin_mask"), std::string("max_seqlen_tensor")}, diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index b3447bb23c3..f20dce42ace 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -257,7 +257,8 @@ class MultiheadMatMulOpConverter : public OpConverter { max_seqlen_tensor); // max_seqlen, eval_placeholder_3 auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin); - layer = plugin_layer; + RreplenishLayerAndOutput( + plugin_layer, "multihead_matmul", {output_name}, test_mode); } else { int head_size = hidden_out / head_number; // [3, head_number, head_size, hidden_in] -> [head_number, 3, @@ -381,7 +382,8 @@ class MultiheadMatMulOpConverter : public OpConverter { auto plugin_layer = engine_->network()->addPluginV2( plugin_inputs.data(), plugin_inputs.size(), *plugin); - + plugin_layer->setName( + ("CustomQKVToContextPluginDynamic: " + output_name).c_str()); // recover no_varlen output if (!flag_varseqlen) { std::vector output_transformer; @@ -394,7 +396,10 @@ class MultiheadMatMulOpConverter : public OpConverter { engine_->AddDynamicPlugin(output_transformer.data(), output_transformer.size(), plugin); - layer = transformer_output_layer; + engine_->SetITensor(output_name, + transformer_output_layer->getOutput(0)); + } else { + engine_->SetITensor(output_name, plugin_layer->getOutput(0)); } } } else { @@ -776,6 +781,8 @@ class MultiheadMatMulOpConverter : public OpConverter { new plugin::QkvToContextPluginDynamic( hidden_in, head_number, head_size, scale, with_fp16); layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); + RreplenishLayerAndOutput( + layer, "multihead_matmul", {output_name}, test_mode); } } } else { @@ -785,8 +792,6 @@ class MultiheadMatMulOpConverter : public OpConverter { "You can use the config.SetTRTDynamicShapeInfo(...) interface to set " "the shape information to run the dynamic shape mode.")); } - RreplenishLayerAndOutput( - layer, "multihead_matmul", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu index 01abd3a6186..42d92f2bb48 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu @@ -255,7 +255,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination( desc.dims.d[0] == prev.dims.d[0]; } if (pos == nbInputs - 1) { // mask id - return desc.type == prev.type; + return desc.type == mType; } // embedded sequence if (pos == nbInputs) { @@ -265,11 +265,11 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination( } // mask(HFace) or pre_layernorm_bias(MTron) if (pos == nbInputs + 1) { - return desc.type == prev.type; + return desc.type == mType; } // max seqlen if (pos == nbInputs + 2) { - return desc.type == prev.type; + return desc.type == mType; } } -- GitLab