Unverified · Commit fa06d9c3, authored by Wangzheee, committed by GitHub

fix_multihead (#45429)

Parent a5e9ccda
@@ -291,7 +291,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
              plugin_inputs.data(), plugin_inputs.size(), *plugin);
          layer = plugin_layer;
        }
-    }
+    } else {
      if (input_dims.d[1] <= 384 && !bias_qk_attr &&
          engine_->precision() != AnalysisConfig::Precision::kFloat32) {
        /*
@@ -392,12 +392,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
        reshape_before_fc_layer->setInput(
            1, *Concat(reshape_before_fc_shape_tensor));
        reshape_before_fc_layer->setName(
-            ("shuffle_before_fc_multihead_matmul(Output: " + output_name + ")")
+            ("shuffle_before_fc_multihead_matmul(Output: " + output_name +
+             ")")
                .c_str());
        // add fc layer
        nvinfer1::ILayer* fc_layer = nullptr;
-        fc_layer = TRT_ENGINE_ADD_LAYER(engine_,
+        fc_layer =
+            TRT_ENGINE_ADD_LAYER(engine_,
                                  FullyConnected,
                                  *reshape_before_fc_layer->getOutput(0),
                                  n,
@@ -427,14 +429,20 @@ class MultiheadMatMulOpConverter : public OpConverter {
          int var_seqlen = 1;
          bool has_mask = true;
          std::vector<nvinfer1::PluginField> fields{
-              {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
+              {"hidden_size",
+               &hidden_out,
+               nvinfer1::PluginFieldType::kINT32,
+               1},
              {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
              {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
              {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
-              {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}};
+              {"var_seqlen",
+               &var_seqlen,
+               nvinfer1::PluginFieldType::kINT32,
+               1}};
          nvinfer1::PluginFieldCollection* plugin_collection =
-              static_cast<nvinfer1::PluginFieldCollection*>(
-                  malloc(sizeof(*plugin_collection) +
-                         fields.size() *
-                             sizeof(nvinfer1::PluginField)));  // remember to free
+              static_cast<nvinfer1::PluginFieldCollection*>(malloc(
+                  sizeof(*plugin_collection) +
+                  fields.size() *
+                      sizeof(nvinfer1::PluginField)));  // remember to free
          plugin_collection->nbFields = static_cast<int>(fields.size());
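
Review note: the hunk above only reflows the `PluginFieldCollection` allocation; the behavior is unchanged. For reference, a minimal standalone sketch of the pattern, with illustrative names (`BuildFieldCollection` and the `fields` assignment are not shown in this patch): the collection is `malloc`ed, populated, handed to the plugin creator, and must later be `free`d, as the `// remember to free` comment warns.

```cpp
#include <cstdlib>
#include <vector>

#include "NvInferRuntimeCommon.h"  // nvinfer1::PluginField(Collection)

// Illustrative helper, not Paddle code: heap-allocate a PluginFieldCollection
// the way the converter above does. The malloc over-allocates room for the
// field array, mirroring the patch; the caller owns the returned pointer and
// must free() it once the plugin has been created.
nvinfer1::PluginFieldCollection* BuildFieldCollection(
    const std::vector<nvinfer1::PluginField>& fields) {
  auto* collection = static_cast<nvinfer1::PluginFieldCollection*>(malloc(
      sizeof(nvinfer1::PluginFieldCollection) +
      fields.size() * sizeof(nvinfer1::PluginField)));  // remember to free
  collection->nbFields = static_cast<int>(fields.size());
  // Assumption: the collection borrows the caller's field storage, so the
  // vector must outlive it.
  collection->fields = fields.data();
  return collection;
}
```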
@@ -506,8 +514,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
              plugin_inputs.data(), plugin_inputs.size(), *plugin);
          // add shuffle
-          auto* reshape_after_mha_layer =
-              TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *plugin_layer->getOutput(0));
+          auto* reshape_after_mha_layer = TRT_ENGINE_ADD_LAYER(
+              engine_, Shuffle, *plugin_layer->getOutput(0));
          std::vector<nvinfer1::ITensor*> reshape_tensor;
          reshape_tensor.push_back(batch_tensor);
          reshape_tensor.push_back(length_tensor);
@@ -554,8 +562,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
        auto* reshape_before_fc_layer =
            TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
        if (op_desc.HasAttr("Input_scale")) {
-          engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0),
-                                         in_scale);
+          engine_->SetTensorDynamicRange(
+              reshape_before_fc_layer->getOutput(0), in_scale);
        }
        reshape_before_fc_layer->setInput(
            1, *Concat(reshape_before_fc_shape_tensor));
@@ -586,11 +594,11 @@ class MultiheadMatMulOpConverter : public OpConverter {
        }
        if (op_desc.HasAttr("fc_out_threshold")) {
-          PADDLE_ENFORCE_EQ(
-              op_desc.HasAttr("fc_out_threshold"),
-              true,
-              platform::errors::InvalidArgument(
-                  "must have out threshold in multihead layers in int8 mode"));
+          PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"),
+                            true,
+                            platform::errors::InvalidArgument(
+                                "must have out threshold in multihead layers "
+                                "in int8 mode"));
          float out_scale =
              PADDLE_GET_CONST(float, op_desc.GetAttr("fc_out_threshold"));
          engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale);
@@ -619,6 +627,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
            hidden_in, head_number, head_size, scale, with_fp16);
        layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin);
      }
+    }
  } else {
    PADDLE_THROW(platform::errors::Fatal(
        "You are running the Ernie(Bert) model in static shape mode, which "