未验证 提交 28c36d86 编写于 作者: S Shang Zhizhou 提交者: GitHub

fix ernie_varlen when cutting head (#31497) (#31512)

上级 b54640bb
......@@ -49,14 +49,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
// (hidden, 3, all_head_size)
// (hidden_in, 3, hidden_out)
auto weight_dims = weight_t->dims();
int hidden = weight_dims[0]; // channels_in
int three = weight_dims[1]; // channels_out
int all_head_size = weight_dims[2]; // channels_out
int m = hidden;
int n = three * all_head_size;
int hidden_in = weight_dims[0]; // channels_in
int three = weight_dims[1]; // channels_out
int hidden_out = weight_dims[2]; // channels_out
int m = hidden_in;
int n = three * hidden_out;
auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
for (int i = 0; i < m; i++) {
for (int j = 0; j < n; j++) {
......@@ -72,21 +72,23 @@ class MultiheadMatMulOpConverter : public OpConverter {
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) {
int head_size = hidden / head_number;
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
auto transpose_weight_v2 = [](const float* src, float* dst, int N,
int H) {
const int HNH = H * N * H;
for (int i = 0; i < 3; ++i) {
for (int n = 0; n < N; ++n) {
for (int hnh = 0; hnh < HNH; ++hnh) {
dst[n * 3 * HNH + i * HNH + hnh] =
src[i * N * HNH + n * HNH + hnh];
int head_size = hidden_out / head_number;
// [3, head_number, head_size, hidden_in] -> [head_number, 3, head_size,
// hidden_in]
auto transpose_weight_v2 = [](const float* src, float* dst, int three,
int head_number, int head_size,
int hidden_in) {
const int HH = head_size * hidden_in;
for (int i = 0; i < three; ++i) {
for (int n = 0; n < head_number; ++n) {
for (int hh = 0; hh < HH; ++hh) {
dst[n * three * HH + i * HH + hh] =
src[i * head_number * HH + n * HH + hh];
}
}
}
};
// [3, N, H] -> [N, 3, H]
// [3, head_number, head_size] -> [head_number, 3, head_size]
auto transpose_bias_v2 = [](const float* src, float* dst, int N,
int H) {
for (int i = 0; i < 3; ++i) {
......@@ -99,8 +101,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
};
memcpy(weight_data_tmp.data(), weight_data,
weight_t->numel() * sizeof(float));
transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number,
head_size);
transpose_weight_v2(weight_data_tmp.data(), weight_data, three,
head_number, head_size, hidden_in);
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
......@@ -130,7 +132,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
int var_seqlen = 1;
const std::vector<nvinfer1::PluginField> fields{
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1},
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1},
......@@ -185,7 +187,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
n, weight.get(), bias.get());
auto* fc_out = fc_layer->getOutput(0);
// add qkv to context
int head_size = all_head_size / head_number;
int head_size = hidden_out / head_number;
float scale = boost::get<float>(op_desc.GetAttr("alpha"));
std::vector<nvinfer1::ITensor*> plugin_inputs;
......@@ -194,7 +196,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden, head_number,
new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
head_size, scale, with_fp16);
layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin);
}
......
......@@ -54,7 +54,12 @@ nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
nvinfer1::DimsExprs output(inputs[0]);
output.nbDims++;
for (int i = output.nbDims - 1; i > 1; i--) {
output.d[i] = inputs[0].d[i - 1];
}
auto one = expr_builder.constant(1);
output.d[1] = one;
output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB,
*inputs[1].d[0], *one);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册