From 6b59a073daa3280e801976af4f584b3b6ae626fa Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Mon, 19 Sep 2022 15:36:58 +0800 Subject: [PATCH] fix_recover_remove_padding kernel (#46050) (#46198) --- .../tensorrt/plugin/recover_padding_plugin.cu | 23 ++++++++++++++++- .../tensorrt/plugin/remove_padding_plugin.cu | 25 ++++++++++++++++--- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu index 3963b48a26..c6be871709 100644 --- a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu @@ -118,7 +118,28 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const int32_t* input1 = static_cast(inputs[1]); // pos_id_tensor float* output = static_cast(outputs[0]); - const int32_t num_threads = 256; + int32_t num_threads; + if (input0_desc.dims.d[1] % 512 == 0) { + num_threads = 512; + } else if (input0_desc.dims.d[1] % 256 == 0) { + num_threads = 256; + } else if (input0_desc.dims.d[1] % 128 == 0) { + num_threads = 128; + } else if (input0_desc.dims.d[1] % 64 == 0) { + num_threads = 64; + } else if (input0_desc.dims.d[1] % 32 == 0) { + num_threads = 32; + } else if (input0_desc.dims.d[1] % 16 == 0) { + num_threads = 16; + } else if (input0_desc.dims.d[1] % 8 == 0) { + num_threads = 8; + } else if (input0_desc.dims.d[1] % 4 == 0) { + num_threads = 4; + } else if (input0_desc.dims.d[1] % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } const dim3 num_blocks( input1_desc.dims.d[0] - 1, input2_desc.dims.d[1], diff --git a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu index 418ecb0157..9f1a1d6d2c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu @@ -110,10 +110,29 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const int32_t* input1 = static_cast(inputs[1]); // pos_id_tensor float* output = static_cast(outputs[0]); - const auto input0_desc = inputDesc[0]; - - const int32_t num_threads = 256; + int32_t num_threads; + if (input0_desc.dims.d[2] % 512 == 0) { + num_threads = 512; + } else if (input0_desc.dims.d[2] % 256 == 0) { + num_threads = 256; + } else if (input0_desc.dims.d[2] % 128 == 0) { + num_threads = 128; + } else if (input0_desc.dims.d[2] % 64 == 0) { + num_threads = 64; + } else if (input0_desc.dims.d[2] % 32 == 0) { + num_threads = 32; + } else if (input0_desc.dims.d[2] % 16 == 0) { + num_threads = 16; + } else if (input0_desc.dims.d[2] % 8 == 0) { + num_threads = 8; + } else if (input0_desc.dims.d[2] % 4 == 0) { + num_threads = 4; + } else if (input0_desc.dims.d[2] % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } const dim3 num_blocks( input0_desc.dims.d[0], input0_desc.dims.d[1], -- GitLab