From 28c36d8682dddc38bc91ff38b3d644a0fc42102e Mon Sep 17 00:00:00 2001
From: Shang Zhizhou
Date: Thu, 11 Mar 2021 15:19:45 +0800
Subject: [PATCH] fix ernie_varlen when cutting head (#31497) (#31512)

---
 .../tensorrt/convert/multihead_matmul_op.cc   | 46 ++++++++++---------
 .../tensorrt/plugin/special_slice_plugin.cu   |  5 ++
 2 files changed, 29 insertions(+), 22 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
index 0beced7363..82c01490d9 100644
--- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
@@ -49,14 +49,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
       memcpy(weight_data_tmp.data(), weight_data,
              weight_t->numel() * sizeof(float));
 
-      // (hidden, 3, all_head_size)
+      // (hidden_in, 3, hidden_out)
       auto weight_dims = weight_t->dims();
 
-      int hidden = weight_dims[0];         // channels_in
-      int three = weight_dims[1];          // channels_out
-      int all_head_size = weight_dims[2];  // channels_out
-      int m = hidden;
-      int n = three * all_head_size;
+      int hidden_in = weight_dims[0];   // channels_in
+      int three = weight_dims[1];       // channels_out
+      int hidden_out = weight_dims[2];  // channels_out
+      int m = hidden_in;
+      int n = three * hidden_out;
       auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
         for (int i = 0; i < m; i++) {
           for (int j = 0; j < n; j++) {
@@ -72,21 +72,23 @@ class MultiheadMatMulOpConverter : public OpConverter {
 
       if (engine_->with_dynamic_shape()) {
         if (engine_->use_oss()) {
-          int head_size = hidden / head_number;
-          // [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
-          auto transpose_weight_v2 = [](const float* src, float* dst, int N,
-                                        int H) {
-            const int HNH = H * N * H;
-            for (int i = 0; i < 3; ++i) {
-              for (int n = 0; n < N; ++n) {
-                for (int hnh = 0; hnh < HNH; ++hnh) {
-                  dst[n * 3 * HNH + i * HNH + hnh] =
-                      src[i * N * HNH + n * HNH + hnh];
+          int head_size = hidden_out / head_number;
+          // [3, head_number, head_size, hidden_in] -> [head_number, 3, head_size,
+          // hidden_in]
+          auto transpose_weight_v2 = [](const float* src, float* dst, int three,
+                                        int head_number, int head_size,
+                                        int hidden_in) {
+            const int HH = head_size * hidden_in;
+            for (int i = 0; i < three; ++i) {
+              for (int n = 0; n < head_number; ++n) {
+                for (int hh = 0; hh < HH; ++hh) {
+                  dst[n * three * HH + i * HH + hh] =
+                      src[i * head_number * HH + n * HH + hh];
                 }
               }
             }
           };
-          // [3, N, H] -> [N, 3, H]
+          // [3, head_number, head_size] -> [head_number, 3, head_size]
           auto transpose_bias_v2 = [](const float* src, float* dst, int N,
                                       int H) {
             for (int i = 0; i < 3; ++i) {
@@ -99,8 +101,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
           };
           memcpy(weight_data_tmp.data(), weight_data,
                  weight_t->numel() * sizeof(float));
-          transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number,
-                              head_size);
+          transpose_weight_v2(weight_data_tmp.data(), weight_data, three,
+                              head_number, head_size, hidden_in);
           nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
                                    static_cast<void*>(weight_data),
                                    static_cast<int32_t>(weight_t->numel())};
@@ -130,7 +132,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
           int var_seqlen = 1;
           const std::vector<nvinfer1::PluginField> fields{
               {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
-              {"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1},
+              {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
               {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
               {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
               {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1},
@@ -185,7 +187,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
                                               n, weight.get(), bias.get());
         auto* fc_out = fc_layer->getOutput(0);
         // add qkv to context
-        int head_size = all_head_size / head_number;
+        int head_size = hidden_out / head_number;
         float scale = boost::get<float>(op_desc.GetAttr("alpha"));
 
         std::vector<nvinfer1::ITensor*> plugin_inputs;
@@ -194,7 +196,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
         bool with_fp16 =
             engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
         plugin::DynamicPluginTensorRT* plugin =
-            new plugin::QkvToContextPluginDynamic(hidden, head_number,
+            new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
                                                   head_size, scale, with_fp16);
         layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin);
       }
diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
index ed0a530439..250b944652 100644
--- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu
@@ -54,7 +54,12 @@ nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
     int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
     nvinfer1::IExprBuilder& expr_builder) {
   nvinfer1::DimsExprs output(inputs[0]);
+  output.nbDims++;
+  for (int i = output.nbDims - 1; i > 1; i--) {
+    output.d[i] = inputs[0].d[i - 1];
+  }
   auto one = expr_builder.constant(1);
+  output.d[1] = one;
   output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
                                        *inputs[1].d[0], *one);
 
--
GitLab
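
Note (annotation, not part of the patch): the converter fix distinguishes
hidden_in (rows of the QKV weight, the channels into the fc) from
hidden_out = head_number * head_size (columns per Q/K/V). The old
transpose_weight_v2 indexed with HNH = H * N * H, which silently assumes a
square weight (hidden_in == head_number * head_size); that assumption breaks
once heads are cut. Below is a minimal standalone C++ sketch that exercises
the same loop as the patched lambda with a deliberately non-square shape;
the sizes are made up for illustration:

    // check_transpose.cc - standalone, illustrative only
    #include <cassert>
    #include <cstddef>
    #include <vector>

    int main() {
      // Deliberately non-square: hidden_in != head_number * head_size.
      const int three = 3, head_number = 2, head_size = 4, hidden_in = 16;
      const int HH = head_size * hidden_in;
      std::vector<float> src(three * head_number * HH), dst(src.size());
      for (std::size_t i = 0; i < src.size(); ++i) src[i] = static_cast<float>(i);

      // Same loop body as the patched transpose_weight_v2:
      // [3, head_number, head_size, hidden_in] ->
      // [head_number, 3, head_size, hidden_in]
      for (int i = 0; i < three; ++i)
        for (int n = 0; n < head_number; ++n)
          for (int hh = 0; hh < HH; ++hh)
            dst[n * three * HH + i * HH + hh] =
                src[i * head_number * HH + n * HH + hh];

      // Element (i=1, n=1, hh=5) of src must land at (n=1, i=1, hh=5) of dst.
      assert(dst[1 * three * HH + 1 * HH + 5] ==
             src[1 * head_number * HH + 1 * HH + 5]);
      return 0;
    }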
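
Note (annotation, not part of the patch): the special_slice change makes
getOutputDimensions insert a unit dimension rather than only overwriting
d[0]. Assuming inputs[0] is the varlen hidden tensor [total_tokens, hidden]
and inputs[1] is the sequence-offsets tensor of length batch + 1 (the usual
ernie varlen layout; this reading of the inputs is an assumption), the output
shape becomes [batch, 1, hidden] instead of the previous [batch, hidden].
The same shape arithmetic mirrored on plain ints, since IExprBuilder is
TensorRT-specific:

    // check_dims.cc - standalone, illustrative only
    #include <cassert>

    int main() {
      int in0[8] = {100, 768};  // inputs[0]: [total_tokens, hidden]
      int in0_rank = 2;
      int offsets_len = 5;      // inputs[1].d[0] = batch + 1 (assumed layout)

      int out[8];
      int out_rank = in0_rank + 1;             // output.nbDims++
      for (int i = out_rank - 1; i > 1; i--) {
        out[i] = in0[i - 1];                   // shift trailing dims right by one
      }
      out[1] = 1;                              // the inserted unit dimension
      out[0] = offsets_len - 1;                // batch = len(offsets) - 1

      assert(out_rank == 3 && out[0] == 4 && out[1] == 1 && out[2] == 768);
      return 0;
    }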