From 49e4ee27e12df5080dc15e766f93f6a25f100e42 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Mon, 15 Jun 2020 19:41:29 +0800 Subject: [PATCH] [Paddle-TRT] slice kernel optimization (#24783) * parallel move shared data test=develop * test=develop --- paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index 7b2b7b10f08..4fb1d824108 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -35,10 +35,8 @@ __global__ void SliceKernel(int num, int dims, const T *input, const int idx = blockIdx.x * blockDim.x + threadIdx.x; extern __shared__ int shared_data[]; - if (threadIdx.x == 0) { - for (int i = 0; i < dims * 3; i++) { - shared_data[i] = offsets_info[i]; - } + for (int i = threadIdx.x; i < dims * 3; i += blockDim.x) { + shared_data[i] = offsets_info[i]; } __syncthreads(); -- GitLab