diff --git a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu index 3963b48a26c6c7f9a3b747504f3d1c80cb2daa07..c6be871709452aefd5cbbb2db35323623b72bcf4 100644 --- a/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/recover_padding_plugin.cu @@ -118,7 +118,28 @@ int RecoverPaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const int32_t* input1 = static_cast(inputs[1]); // pos_id_tensor float* output = static_cast(outputs[0]); - const int32_t num_threads = 256; + int32_t num_threads; + if (input0_desc.dims.d[1] % 512 == 0) { + num_threads = 512; + } else if (input0_desc.dims.d[1] % 256 == 0) { + num_threads = 256; + } else if (input0_desc.dims.d[1] % 128 == 0) { + num_threads = 128; + } else if (input0_desc.dims.d[1] % 64 == 0) { + num_threads = 64; + } else if (input0_desc.dims.d[1] % 32 == 0) { + num_threads = 32; + } else if (input0_desc.dims.d[1] % 16 == 0) { + num_threads = 16; + } else if (input0_desc.dims.d[1] % 8 == 0) { + num_threads = 8; + } else if (input0_desc.dims.d[1] % 4 == 0) { + num_threads = 4; + } else if (input0_desc.dims.d[1] % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } const dim3 num_blocks( input1_desc.dims.d[0] - 1, input2_desc.dims.d[1], diff --git a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu index 418ecb015784fe21108ae2a29a08542a82ead32a..9f1a1d6d2c109a54a8419c787445daf156c81552 100644 --- a/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/remove_padding_plugin.cu @@ -110,10 +110,29 @@ int RemovePaddingPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const int32_t* input1 = static_cast(inputs[1]); // pos_id_tensor float* output = static_cast(outputs[0]); - const auto input0_desc = inputDesc[0]; - - const int32_t num_threads = 256; + int32_t num_threads; + if (input0_desc.dims.d[2] % 512 == 0) { + num_threads = 512; + } else if (input0_desc.dims.d[2] % 256 == 0) { + num_threads = 256; + } else if (input0_desc.dims.d[2] % 128 == 0) { + num_threads = 128; + } else if (input0_desc.dims.d[2] % 64 == 0) { + num_threads = 64; + } else if (input0_desc.dims.d[2] % 32 == 0) { + num_threads = 32; + } else if (input0_desc.dims.d[2] % 16 == 0) { + num_threads = 16; + } else if (input0_desc.dims.d[2] % 8 == 0) { + num_threads = 8; + } else if (input0_desc.dims.d[2] % 4 == 0) { + num_threads = 4; + } else if (input0_desc.dims.d[2] % 2 == 0) { + num_threads = 2; + } else { + num_threads = 1; + } const dim3 num_blocks( input0_desc.dims.d[0], input0_desc.dims.d[1],