Unverified commit 439b2b94, authored by Wangzheee, committed by GitHub

fix embedding multihead (#49085)

Parent e577040e
@@ -181,7 +181,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
       layer = plugin_layer;
       auto output_name = op_desc.Output("Out")[0];
       RreplenishLayerAndOutput(layer,
-                               "ManyEmbLayerNormPluginDynamic_V1",
+                               "ManyEmbLayerNormVarlenPluginDynamicV1",
                                {output_name,
                                 std::string("qkv_plugin_mask"),
                                 std::string("max_seqlen_tensor")},
......
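Note on the hunk above: the only change is the label string passed to RreplenishLayerAndOutput, brought in line with the varlen plugin's actual name. As a rough sketch of what that helper does (paraphrased from Paddle's OpConverter base class; the exact signature and name format may differ from the real source), the label feeds the layer's display name, while the output list is what binds TensorRT tensors to Paddle variable names:

// Hedged sketch, not a verbatim copy of the Paddle helper.
void OpConverter::RreplenishLayerAndOutput(
    nvinfer1::ILayer* layer,
    const std::string& layer_type,
    const std::vector<std::string>& output_names,
    bool test_mode) {
  for (size_t i = 0; i < output_names.size(); ++i) {
    // Publish the i-th TensorRT output under its Paddle variable name.
    engine_->SetITensor(output_names[i], layer->getOutput(i));
    if (test_mode) {
      engine_->DeclareOutput(output_names[i]);  // expose as a network output
    }
  }
  // The layer_type label only affects this display name (as seen in logs
  // and profiler output), which is why the rename above is cosmetic
  // rather than behavioral.
  layer->setName((layer_type + " (Output: " + output_names[0] + ")").c_str());
}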
@@ -257,7 +257,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
               max_seqlen_tensor);  // max_seqlen, eval_placeholder_3
           auto plugin_layer = engine_->network()->addPluginV2(
               plugin_inputs.data(), plugin_inputs.size(), *plugin);
-          layer = plugin_layer;
+          RreplenishLayerAndOutput(
+              plugin_layer, "multihead_matmul", {output_name}, test_mode);
         } else {
           int head_size = hidden_out / head_number;
           // [3, head_number, head_size, hidden_in] -> [head_number, 3,
@@ -381,7 +382,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
         auto plugin_layer = engine_->network()->addPluginV2(
             plugin_inputs.data(), plugin_inputs.size(), *plugin);
+        plugin_layer->setName(
+            ("CustomQKVToContextPluginDynamic: " + output_name).c_str());
         // recover no_varlen output
         if (!flag_varseqlen) {
           std::vector<nvinfer1::ITensor*> output_transformer;
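The setName addition is small but useful: layers left unnamed show up in TensorRT's verbose logs and IProfiler reports under auto-generated identifiers, so converters tag plugin layers with the plugin kind plus the Paddle output they feed. The pattern in isolation (variable names as in the diff):

// Illustrative naming pattern for a TensorRT plugin layer.
auto* plugin_layer = engine_->network()->addPluginV2(
    plugin_inputs.data(), plugin_inputs.size(), *plugin);
plugin_layer->setName(
    ("CustomQKVToContextPluginDynamic: " + output_name).c_str());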
@@ -394,7 +396,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
               engine_->AddDynamicPlugin(output_transformer.data(),
                                         output_transformer.size(),
                                         plugin);
-          layer = transformer_output_layer;
+          engine_->SetITensor(output_name,
+                              transformer_output_layer->getOutput(0));
+        } else {
+          engine_->SetITensor(output_name, plugin_layer->getOutput(0));
         }
       }
     } else {
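The two hunks above restructure how the varlen path publishes its result. Previously the branch stashed a layer in `layer` and relied on a single RreplenishLayerAndOutput at the end of the function (removed in the final hunk of this file); now each case binds output_name itself. A condensed sketch of the post-change flow, using the variable names from the diff:

// Condensed sketch; surrounding setup omitted.
if (!flag_varseqlen) {
  // Callers not using the varlen format need the transformer-format result
  // converted back, so the recovery plugin's output is what gets published.
  engine_->SetITensor(output_name,
                      transformer_output_layer->getOutput(0));
} else {
  // Varlen callers consume the QKV-to-context plugin output directly; this
  // binding is the piece the shared epilogue previously had to cover.
  engine_->SetITensor(output_name, plugin_layer->getOutput(0));
}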
@@ -776,6 +781,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
                 new plugin::QkvToContextPluginDynamic(
                     hidden_in, head_number, head_size, scale, with_fp16);
         layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
+        RreplenishLayerAndOutput(
+            layer, "multihead_matmul", {output_name}, test_mode);
       }
     }
   } else {
@@ -785,8 +792,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
         "You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
         "the shape information to run the dynamic shape mode."));
   }
-  RreplenishLayerAndOutput(
-      layer, "multihead_matmul", {output_name}, test_mode);
 }
};
......
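The last two hunks in this file complete the restructuring: the QkvToContextPluginDynamic branch now registers its own output (the -776 hunk), and the unconditional RreplenishLayerAndOutput after the dynamic-shape check is deleted (the -785 hunk). That trailing call had to go: with the varlen branch no longer assigning `layer` (see the -257 hunk), leaving it in place would re-register output_name on some paths and touch an unset layer on others. A sketch of the hazard (illustrative only; varlen_path stands in for the real branch condition in the converter):

// Illustrative control flow after the fix; `layer` stays null on the
// varlen path, so a shared epilogue that dereferenced it would fail.
nvinfer1::ILayer* layer = nullptr;
if (varlen_path) {
  RreplenishLayerAndOutput(
      plugin_layer, "multihead_matmul", {output_name}, test_mode);
} else {
  layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
  RreplenishLayerAndOutput(
      layer, "multihead_matmul", {output_name}, test_mode);
}
// Removed by this commit -- would double-register output_name and, on the
// varlen path, call getOutput(0) on a null layer:
// RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, test_mode);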
@@ -255,7 +255,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
            desc.dims.d[0] == prev.dims.d[0];
   }
   if (pos == nbInputs - 1) {  // mask id
-    return desc.type == prev.type;
+    return desc.type == mType;
   }
   // embedded sequence
   if (pos == nbInputs) {
@@ -265,11 +265,11 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
   }
   // mask(HFace) or pre_layernorm_bias(MTron)
   if (pos == nbInputs + 1) {
-    return desc.type == prev.type;
+    return desc.type == mType;
   }
   // max seqlen
   if (pos == nbInputs + 2) {
-    return desc.type == prev.type;
+    return desc.type == mType;
   }
 }
......
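The plugin-side hunks fix type checks in EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination. Three slots (the mask-id input, the HFace mask / MTron pre_layernorm_bias output, and the max-seqlen output) were validated against prev.type, the type of a neighboring tensor in the combination being probed; they are now anchored to mType, the data type the plugin itself was configured with. TensorRT probes format combinations position by position, so a check against a neighbor that is itself still unconstrained can let through combinations the kernels do not support. A simplified sketch of the corrected shape of the method (member context assumed; the real method also checks formats, dims, and the remaining slots shown in the diff):

// Simplified sketch inside the plugin class; mType is the member holding
// the plugin's configured data type.
bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
    int32_t pos,
    const nvinfer1::PluginTensorDesc* inOut,
    int32_t nbInputs,
    int32_t nbOutputs) noexcept {
  const nvinfer1::PluginTensorDesc& desc = inOut[pos];
  if (pos == nbInputs - 1) {    // mask id
    return desc.type == mType;  // anchor to the plugin's own type,
  }                             // not to a neighbor's possibly
  if (pos == nbInputs + 1) {    // mask (HFace) / pre_layernorm_bias (MTron)
    return desc.type == mType;  // still-unconstrained type
  }
  if (pos == nbInputs + 2) {    // max seqlen
    return desc.type == mType;
  }
  // Remaining positions (word/pos ids, embedded sequence) are handled by
  // the cases elided here.
  return false;
}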