未验证 提交 953024ff 编写于 作者: Z zhoutianzi666 提交者: GitHub

fix stack_op_plugin (#44045)

上级 a0dc361c
@@ -152,8 +152,8 @@ __global__ void StackKernel(const T* const* input,
                             T* output,
                             int num_stack,
                             int base_unit) {
-  int stack_id = blockIdx.x;
-  int lead_id = blockIdx.y;
+  int stack_id = blockIdx.y;
+  int lead_id = blockIdx.x;
   for (int i = threadIdx.x; i < base_unit; i += blockDim.x) {
     output[lead_id * num_stack * base_unit + stack_id * base_unit + i] =
@@ -201,7 +201,8 @@ int StackPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* input_desc,
                                 stream);
   const int num_stacks = out_dims.d[axis_];
-  dim3 num_blocks(num_stacks, lead_unit);
+  // lead_unit may be very large, so make it be blockIdx.x
+  dim3 num_blocks(lead_unit, num_stacks);
   const int num_threads = 256;
   auto infer_type = input_desc[0].type;
...
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册