From 953024ff4fb81543825df952670c16109fb2bf90 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com>
Date: Wed, 6 Jul 2022 09:42:59 +0800
Subject: [PATCH] fix stack_op_plugin (#44045)

---
 paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
index 4ef160d2e04..e77f12769c0 100644
--- a/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/stack_op_plugin.cu
@@ -152,8 +152,8 @@ __global__ void StackKernel(const T* const* input,
                             T* output,
                             int num_stack,
                             int base_unit) {
-  int stack_id = blockIdx.x;
-  int lead_id = blockIdx.y;
+  int stack_id = blockIdx.y;
+  int lead_id = blockIdx.x;
   for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
     output[lead_id * num_stack * base_unit + stack_id * base_unit + i] =
         input[stack_id][lead_id * base_unit + i];
@@ -201,7 +201,8 @@ int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                   stream);
 
   const int num_stacks = out_dims.d[axis_];
-  dim3 num_blocks(num_stacks, lead_unit);
+  // lead_unit may be very large, so make it be blockIdx.x
+  dim3 num_blocks(lead_unit, num_stacks);
   const int num_threads = 256;
   auto infer_type = input_desc[0].type;
 
-- 
GitLab
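
Note on why the swap matters: CUDA caps gridDim.y (and gridDim.z) at 65535 blocks, while gridDim.x may be as large as 2^31 - 1 on devices of compute capability 3.0 and above. The leading extent lead_unit (the product of the dimensions before the stack axis) can easily exceed 65535 for large inputs, so the old launch shape dim3 num_blocks(num_stacks, lead_unit) could fail with cudaErrorInvalidConfiguration, whereas num_stacks (the size of the stacked axis) is typically small and fits safely in the y dimension. Below is a minimal standalone sketch, not part of the patch, that queries these limits and mimics the fixed launch shape; the values of lead_unit and num_stacks are made up for illustration.

// sketch.cu -- illustrative only; build with `nvcc sketch.cu -o sketch`.
// "lead_unit" and "num_stacks" are stand-ins for the values computed in
// StackPluginDynamic::enqueue.
#include <cstdio>
#include <cuda_runtime.h>

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // Typical values: maxGridSize[0] = 2147483647,
  //                 maxGridSize[1] = maxGridSize[2] = 65535.
  printf("max grid dims: x=%d y=%d z=%d\n",
         prop.maxGridSize[0], prop.maxGridSize[1], prop.maxGridSize[2]);

  int num_stacks = 4;       // size of the stacked axis (usually small)
  int lead_unit = 1 << 20;  // leading-dims product; can exceed 65535
  // Old launch shape: dim3 num_blocks(num_stacks, lead_unit);
  //   -> y-dimension overflows for lead_unit > 65535.
  // Fixed launch shape: the large extent rides on x, which has the big limit.
  dim3 num_blocks(lead_unit, num_stacks);
  printf("launch grid: x=%u y=%u\n", num_blocks.x, num_blocks.y);
  return 0;
}

The kernel change is the mirror image of the launch change: with lead_unit now in the x dimension, StackKernel must read lead_id from blockIdx.x and stack_id from blockIdx.y, which is exactly what the first hunk does, leaving the output indexing arithmetic untouched.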