diff --git a/paddle/fluid/operators/gather_scatter_kernel.cu b/paddle/fluid/operators/gather_scatter_kernel.cu
index dc87fc52aacb4c4ea3b263dbef2f37d0ef0347e0..f97eb3d5e9d9a348e3ad151cc234f0989c937985 100644
--- a/paddle/fluid/operators/gather_scatter_kernel.cu
+++ b/paddle/fluid/operators/gather_scatter_kernel.cu
@@ -119,7 +119,7 @@ struct gpu_gather_scatter_functor {
         is_scatter_like ? self_dims[dim] : src_dims[dim];
     int64_t inner_dim_size = 1;
     int64_t outer_dim_size = 1;
-    for (int64_t i = 0; i < index_dims.size(); ++i) {
+    for (int64_t i = 0; i < dim; ++i) {
       inner_dim_size *= index_dims[i];
     }
 
@@ -127,11 +127,8 @@ struct gpu_gather_scatter_functor {
       outer_dim_size *= index_dims[i];
     }
 
-    int64_t slice_size = 1;
-    for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];
-
     int block = 512;
-    int64_t n = slice_size * index_size;
+    int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
     int64_t grid = (n + block - 1) / block;
     auto stream =
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
@@ -215,11 +212,8 @@ void gpu_scatter_input_grad_kernel(Tensor self, int dim, const Tensor& index,
     outer_dim_size *= index_dims[i];
   }
 
-  int64_t slice_size = 1;
-  for (int i = 1; i < grad_dims.size(); ++i) slice_size *= grad_dims[i];
-
   int block = 512;
-  int64_t n = slice_size * index_size;
+  int64_t n = inner_dim_size * select_dim_size * outer_dim_size;
  int64_t grid = (n + block - 1) / block;
   auto stream =
       reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream();
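
The hunks make the same correction in the gather/scatter functor and in the scatter input-grad kernel: `inner_dim_size` now accumulates only the index dimensions before `dim`, and the launch size is taken from the index tensor itself (`inner_dim_size * select_dim_size * outer_dim_size`, i.e. `index.numel()`) rather than from a `slice_size`-based count derived from the src/grad shape. Below is a minimal host-side sketch of that sizing; `launch_elements` is a hypothetical helper introduced here for illustration, and a plain `std::vector` stands in for Paddle's `index.dims()`.

#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: reproduces how the patch sizes the kernel launch.
// `dims` stands in for Paddle's index.dims(); `dim` is the gather/scatter axis.
int64_t launch_elements(const std::vector<int64_t>& dims, int dim) {
  int64_t inner_dim_size = 1;  // product of index dims before `dim`
  int64_t outer_dim_size = 1;  // product of index dims after `dim`
  for (int64_t i = 0; i < dim; ++i) inner_dim_size *= dims[i];
  for (int64_t i = dim + 1; i < static_cast<int64_t>(dims.size()); ++i)
    outer_dim_size *= dims[i];
  int64_t select_dim_size = dims[dim];
  // One thread per index element, independent of the self/src/grad shapes.
  return inner_dim_size * select_dim_size * outer_dim_size;
}

int main() {
  // A 2x3x4 index tensor, gathered/scattered along dim = 1:
  // 2 * 3 * 4 = 24 elements, i.e. index.numel().
  std::vector<int64_t> dims = {2, 3, 4};
  int64_t n = launch_elements(dims, /*dim=*/1);
  assert(n == 24);
  int block = 512;                         // same block size as the kernel
  int64_t grid = (n + block - 1) / block;  // same rounding as the patch
  assert(grid == 1);
  return 0;
}

Under the old `slice_size * index_size` formula the thread count tracked the src/grad slice instead of the index tensor, which presumably stops matching the per-element indexing once the index shape differs from the source along `dim`.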