未验证 提交 6578da51 编写于 作者: W Wangzheee 提交者: GitHub

fix qk bias for multihead (#49702)

上级 2bb28f31
......@@ -131,12 +131,13 @@ class MultiheadMatMulOpConverter : public OpConverter {
step_dims.d[3] = 1;
auto* shape_tensor = Shape(bias_qk_tensor);
// (b,n,m,m) -> (b,1,m,1)
// (b,*,*,m) -> (b,1,1,m)
std::vector<nvinfer1::ITensor*> size_vec_tensor;
size_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, 0));
size_vec_tensor.push_back(Add1DConstantLayer(1));
size_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, 2));
size_vec_tensor.push_back(Add1DConstantLayer(1));
size_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, 3));
auto* size_tensor = Concat(size_vec_tensor);
auto* slice_layer = TRT_ENGINE_ADD_LAYER(engine_,
Slice,
......@@ -163,7 +164,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
TRT_ENGINE_ADD_LAYER(engine_, Identity, *not_layer->getOutput(0));
cast_layer_1->setOutputType(0, nvinfer1::DataType::kINT32);
// Calculate the number of 1 : (b,1,m,1) -> (b)
// Calculate the number of 1 : (b,1,1,m) -> (b)
uint32_t reduce_dim_0 = 0;
reduce_dim_0 |= 1 << 1; // 00000000000000000000000000000010
reduce_dim_0 |= 1 << 2; // 00000000000000000000000000000110
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册