From 937c5ceec1906e1a8f8e21d2dd165f5280050b26 Mon Sep 17 00:00:00 2001 From: zmx Date: Mon, 1 Mar 2021 05:24:32 +0800 Subject: [PATCH] issue with the implementation of column_sum_reduce (#804) hi, i take a look at the code of column_sum_reduce, i have 2 questions: 1. the goal of column_sum_reduce is to get the column sum of inp matrix with shape[rows, width] and the result shape should be [width],right ? It seems that the judgment condition of pos is not suitable 2. the implementation of cuda kernel based on the asumption that, the thread with same threadIdx.y will group into a thread_block_tile, the blockDim is (32,32), i read the nvidia document https://on-demand.gputechconf.com/gtc/2017/presentation/s7622-Kyrylo-perelygin-robust-and-scalable-cuda.pdf, THREAD BLOCK TILE is a subset of threads of a thread block, divided into tiles in row-major order. doesn't it mean thread with the same threadIdx.x will group into a thread_block_tile ? thanks !!!! Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> --- csrc/transformer/general_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index fbe4d053..7d318773 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -43,7 +43,7 @@ __global__ void column_sum_reduce(const T* __restrict__ inp, if (threadIdx.x == 0) { int pos = blockIdx.x * TILE_DIM + threadIdx.y; - if (pos < (rows * width)) out[pos] = sum; + if (pos < width) out[pos] = sum; } } -- GitLab