未验证 提交 439b2b94 编写于 作者: W Wangzheee 提交者: GitHub

fix embedding multihead (#49085)

上级 e577040e
......@@ -181,7 +181,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
layer = plugin_layer;
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer,
"ManyEmbLayerNormPluginDynamic_V1",
"ManyEmbLayerNormVarlenPluginDynamicV1",
{output_name,
std::string("qkv_plugin_mask"),
std::string("max_seqlen_tensor")},
......
......@@ -257,7 +257,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
max_seqlen_tensor); // max_seqlen, eval_placeholder_3
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
RreplenishLayerAndOutput(
plugin_layer, "multihead_matmul", {output_name}, test_mode);
} else {
int head_size = hidden_out / head_number;
// [3, head_number, head_size, hidden_in] -> [head_number, 3,
......@@ -381,7 +382,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
plugin_layer->setName(
("CustomQKVToContextPluginDynamic: " + output_name).c_str());
// recover no_varlen output
if (!flag_varseqlen) {
std::vector<nvinfer1::ITensor*> output_transformer;
......@@ -394,7 +396,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_->AddDynamicPlugin(output_transformer.data(),
output_transformer.size(),
plugin);
layer = transformer_output_layer;
engine_->SetITensor(output_name,
transformer_output_layer->getOutput(0));
} else {
engine_->SetITensor(output_name, plugin_layer->getOutput(0));
}
}
} else {
......@@ -776,6 +781,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
new plugin::QkvToContextPluginDynamic(
hidden_in, head_number, head_size, scale, with_fp16);
layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
RreplenishLayerAndOutput(
layer, "multihead_matmul", {output_name}, test_mode);
}
}
} else {
......@@ -785,8 +792,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."));
}
RreplenishLayerAndOutput(
layer, "multihead_matmul", {output_name}, test_mode);
}
};
......
......@@ -255,7 +255,7 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
desc.dims.d[0] == prev.dims.d[0];
}
if (pos == nbInputs - 1) { // mask id
return desc.type == prev.type;
return desc.type == mType;
}
// embedded sequence
if (pos == nbInputs) {
......@@ -265,11 +265,11 @@ bool EmbLayerNormVarSeqlenPluginBase::supportsFormatCombination(
}
// mask(HFace) or pre_layernorm_bias(MTron)
if (pos == nbInputs + 1) {
return desc.type == prev.type;
return desc.type == mType;
}
// max seqlen
if (pos == nbInputs + 2) {
return desc.type == prev.type;
return desc.type == mType;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册