未验证 提交 6b59a073 编写于 作者: W Wangzheee 提交者: GitHub

fix_recover_remove_padding kernel (#46050) (#46198)

上级 db368d5b
......@@ -118,7 +118,28 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const int32_t* input1 =
static_cast<const int32_t*>(inputs[1]); // pos_id_tensor
float* output = static_cast<float*>(outputs[0]);
const int32_t num_threads = 256;
int32_t num_threads;
if (input0_desc.dims.d[1] % 512 == 0) {
num_threads = 512;
} else if (input0_desc.dims.d[1] % 256 == 0) {
num_threads = 256;
} else if (input0_desc.dims.d[1] % 128 == 0) {
num_threads = 128;
} else if (input0_desc.dims.d[1] % 64 == 0) {
num_threads = 64;
} else if (input0_desc.dims.d[1] % 32 == 0) {
num_threads = 32;
} else if (input0_desc.dims.d[1] % 16 == 0) {
num_threads = 16;
} else if (input0_desc.dims.d[1] % 8 == 0) {
num_threads = 8;
} else if (input0_desc.dims.d[1] % 4 == 0) {
num_threads = 4;
} else if (input0_desc.dims.d[1] % 2 == 0) {
num_threads = 2;
} else {
num_threads = 1;
}
const dim3 num_blocks(
input1_desc.dims.d[0] - 1,
input2_desc.dims.d[1],
......
......@@ -110,10 +110,29 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
const int32_t* input1 =
static_cast<const int32_t*>(inputs[1]); // pos_id_tensor
float* output = static_cast<float*>(outputs[0]);
const auto input0_desc = inputDesc[0];
const int32_t num_threads = 256;
int32_t num_threads;
if (input0_desc.dims.d[2] % 512 == 0) {
num_threads = 512;
} else if (input0_desc.dims.d[2] % 256 == 0) {
num_threads = 256;
} else if (input0_desc.dims.d[2] % 128 == 0) {
num_threads = 128;
} else if (input0_desc.dims.d[2] % 64 == 0) {
num_threads = 64;
} else if (input0_desc.dims.d[2] % 32 == 0) {
num_threads = 32;
} else if (input0_desc.dims.d[2] % 16 == 0) {
num_threads = 16;
} else if (input0_desc.dims.d[2] % 8 == 0) {
num_threads = 8;
} else if (input0_desc.dims.d[2] % 4 == 0) {
num_threads = 4;
} else if (input0_desc.dims.d[2] % 2 == 0) {
num_threads = 2;
} else {
num_threads = 1;
}
const dim3 num_blocks(
input0_desc.dims.d[0],
input0_desc.dims.d[1],
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册