未验证 提交 f57739be 作者: Shang Zhizhou 提交者: GitHub

fix ernie_varlen when cutting head (#31497)

上级 45c7d905
...@@ -49,14 +49,14 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -49,14 +49,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
memcpy(weight_data_tmp.data(), weight_data, memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float)); weight_t->numel() * sizeof(float));
// (hidden, 3, all_head_size) // (hidden_in, 3, hidden_out)
auto weight_dims = weight_t->dims(); auto weight_dims = weight_t->dims();
int hidden = weight_dims[0]; // channels_in int hidden_in = weight_dims[0]; // channels_in
int three = weight_dims[1]; // channels_out int three = weight_dims[1]; // channels_out
int all_head_size = weight_dims[2]; // channels_out int hidden_out = weight_dims[2]; // channels_out
int m = hidden; int m = hidden_in;
int n = three * all_head_size; int n = three * hidden_out;
auto tranpose_weight = [](const float* src, float* dst, int m, int n) { auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
for (int i = 0; i < m; i++) { for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) { for (int j = 0; j < n; j++) {
...@@ -72,21 +72,23 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -72,21 +72,23 @@ class MultiheadMatMulOpConverter : public OpConverter {
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) { if (engine_->use_oss()) {
int head_size = hidden / head_number; int head_size = hidden_out / head_number;
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin] // [3, head_number, head_size, hidden_in] -> [head_number, 3, head_size,
auto transpose_weight_v2 = [](const float* src, float* dst, int N, // hidden_in]
int H) { auto transpose_weight_v2 = [](const float* src, float* dst, int three,
const int HNH = H * N * H; int head_number, int head_size,
for (int i = 0; i < 3; ++i) { int hidden_in) {
for (int n = 0; n < N; ++n) { const int HH = head_size * hidden_in;
for (int hnh = 0; hnh < HNH; ++hnh) { for (int i = 0; i < three; ++i) {
dst[n * 3 * HNH + i * HNH + hnh] = for (int n = 0; n < head_number; ++n) {
src[i * N * HNH + n * HNH + hnh]; for (int hh = 0; hh < HH; ++hh) {
dst[n * three * HH + i * HH + hh] =
src[i * head_number * HH + n * HH + hh];
} }
} }
} }
}; };
// [3, N, H] -> [N, 3, H] // [3, head_number, head_size] -> [head_number, 3, head_size]
auto transpose_bias_v2 = [](const float* src, float* dst, int N, auto transpose_bias_v2 = [](const float* src, float* dst, int N,
int H) { int H) {
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
...@@ -99,8 +101,8 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -99,8 +101,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
}; };
memcpy(weight_data_tmp.data(), weight_data, memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float)); weight_t->numel() * sizeof(float));
transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number, transpose_weight_v2(weight_data_tmp.data(), weight_data, three,
head_size); head_number, head_size, hidden_in);
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())}; static_cast<int32_t>(weight_t->numel())};
...@@ -130,7 +132,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -130,7 +132,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
int var_seqlen = 1; int var_seqlen = 1;
const std::vector<nvinfer1::PluginField> fields{ const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1}, {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1}, {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1},
...@@ -186,7 +188,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -186,7 +188,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
n, weight.get(), bias.get()); n, weight.get(), bias.get());
auto* fc_out = fc_layer->getOutput(0); auto* fc_out = fc_layer->getOutput(0);
// add qkv to context // add qkv to context
int head_size = all_head_size / head_number; int head_size = hidden_out / head_number;
float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha")); float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha"));
std::vector<nvinfer1::ITensor*> plugin_inputs; std::vector<nvinfer1::ITensor*> plugin_inputs;
...@@ -195,7 +197,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -195,7 +197,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
bool with_fp16 = bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::DynamicPluginTensorRT* plugin = plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden, head_number, new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
head_size, scale, with_fp16); head_size, scale, with_fp16);
layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin); layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin);
} }
......
...@@ -54,7 +54,12 @@ nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions( ...@@ -54,7 +54,12 @@ nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) { nvinfer1::IExprBuilder& expr_builder) {
nvinfer1::DimsExprs output(inputs[0]); nvinfer1::DimsExprs output(inputs[0]);
output.nbDims++;
for (int i = output.nbDims - 1; i > 1; i--) {
output.d[i] = inputs[0].d[i - 1];
}
auto one = expr_builder.constant(1); auto one = expr_builder.constant(1);
output.d[1] = one;
output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB, output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
*inputs[1].d[0], *one); *inputs[1].d[0], *one);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册