未验证 提交 c6950ab2 编写于 作者: F FlyingQianMM 提交者: GitHub

add index initialization in the block loop for index_sample kernel when...

add index initialization in the block loop for index_sample kernel when dealing with a input tensor whose shape is larger than block_dim * grid_dim (#39736)

* add block and grid loop for index_sample kernel to deal with a large-shape tensor

* fix code format

* limit grid dim

* fix the omissive initialization of index_i in the second cycle for index_sample kernel

* fix conflicts
上级 553afc07
...@@ -44,6 +44,7 @@ __global__ void IndexSampleForward(const IndexT* index, const T* in_data, ...@@ -44,6 +44,7 @@ __global__ void IndexSampleForward(const IndexT* index, const T* in_data,
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x; unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i; unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i; unsigned int in_idx = index_j * input_length + index_i;
...@@ -62,6 +63,7 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad, ...@@ -62,6 +63,7 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y; unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) { for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
for (; index_i < index_length; index_i += blockDim.x * gridDim.x) { for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i; unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i; unsigned int in_idx = index_j * input_length + index_i;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册