From 1255e7d6fa6a6ca75821273c5839a657cd1a4757 Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Thu, 24 Feb 2022 20:34:11 +0800 Subject: [PATCH] [Paddle-Inference] fix special_slice plugin (#39875) * fix plugin: special slice for ernie --- .../tensorrt/plugin/special_slice_plugin.cu | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu index ecf06e9bf15..324e9c0392c 100644 --- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu @@ -113,12 +113,12 @@ nvinfer1::DataType SpecialSlicePluginDynamic::getOutputDataType( template __global__ void SpecialSliceKernel(const T* slice_input, const int32_t* cu_seqlens, T* output) { - const int hidden = blockDim.x * gridDim.y; - const int batch = blockIdx.x; - const int local_idx = blockIdx.y * blockDim.y + threadIdx.x; + const int hidden = blockDim.x * gridDim.x; + const int hidden_id = blockIdx.x * blockDim.x + threadIdx.x; + const int batch_id = blockIdx.y; - output[batch * hidden + local_idx] = - slice_input[cu_seqlens[batch] * hidden + local_idx]; + output[batch_id * hidden + hidden_id] = + slice_input[cu_seqlens[batch_id] * hidden + hidden_id]; } int SpecialSlicePluginDynamic::enqueue( @@ -137,15 +137,16 @@ int SpecialSlicePluginDynamic::enqueue( "hidden should be multiple of 128.")); constexpr int num_threads = 128; - const dim3 blocks(out_dims.d[0], hidden / num_threads); - const half* slice_input = static_cast(inputs[0]); const int32_t* cu_seqlens = static_cast(inputs[1]); half* output = static_cast(outputs[0]); - SpecialSliceKernel<<>>(slice_input, - cu_seqlens, output); + const int32_t num_blocks_x = hidden / num_threads; + const int32_t num_blocks_y = out_dims.d[0]; // batchs + const dim3 num_blocks(num_blocks_x, num_blocks_y); // blocks + SpecialSliceKernel<<>>( + slice_input, cu_seqlens, output); return cudaGetLastError() != cudaSuccess; } -- GitLab