未验证 提交 fa06d9c3，作者: Wangzheee，提交者: GitHub

fix_multihead (#45429)

上级 a5e9ccda
......@@ -291,7 +291,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.data(), plugin_inputs.size(), *plugin);
layer = plugin_layer;
}
}
} else {
if (input_dims.d[1] <= 384 && !bias_qk_attr &&
engine_->precision() != AnalysisConfig::Precision::kFloat32) {
/*
......@@ -392,12 +392,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
reshape_before_fc_layer->setInput(
1, *Concat(reshape_before_fc_shape_tensor));
reshape_before_fc_layer->setName(
("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")")
("shuffle_before_fc_multihead_matmul(Output: " + output_name +
")")
.c_str());
// add fc layer
nvinfer1::ILayer* fc_layer = nullptr;
fc_layer = TRT_ENGINE_ADD_LAYER(engine_,
fc_layer =
TRT_ENGINE_ADD_LAYER(engine_,
FullyConnected,
*reshape_before_fc_layer->getOutput(0),
n,
......@@ -427,14 +429,20 @@ class MultiheadMatMulOpConverter : public OpConverter {
int var_seqlen = 1;
bool has_mask = true;
std::vector<nvinfer1::PluginField> fields{
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size",
&hidden_out,
nvinfer1::PluginFieldType::kINT32,
1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}};
{"var_seqlen",
&var_seqlen,
nvinfer1::PluginFieldType::kINT32,
1}};
nvinfer1::PluginFieldCollection* plugin_collection =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_collection) +
static_cast<nvinfer1::PluginFieldCollection*>(malloc(
sizeof(*plugin_collection) +
fields.size() *
sizeof(nvinfer1::PluginField))); // remember to free
plugin_collection->nbFields = static_cast<int>(fields.size());
......@@ -506,8 +514,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.data(), plugin_inputs.size(), *plugin);
// add shuffle
auto* reshape_after_mha_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
auto* reshape_after_mha_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *plugin_layer->getOutput(0));
std::vector<nvinfer1::ITensor*> reshape_tensor;
reshape_tensor.push_back(batch_tensor);
reshape_tensor.push_back(length_tensor);
......@@ -554,8 +562,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (op_desc.HasAttr("Input_scale")) {
engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0),
in_scale);
engine_->SetTensorDynamicRange(
reshape_before_fc_layer->getOutput(0), in_scale);
}
reshape_before_fc_layer->setInput(
1, *Concat(reshape_before_fc_shape_tensor));
......@@ -586,11 +594,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
}
if (op_desc.HasAttr("fc_out_threshold")) {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("fc_out_threshold"),
PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"),
true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers in int8 mode"));
"must have out threshold in multihead layers "
"in int8 mode"));
float out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
......@@ -619,6 +627,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
hidden_in, head_number, head_size, scale, with_fp16);
layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
}
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static shape mode, which "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请先注册