From 0b54d54fd847da881116f3c8628ec449c5c0d5d3 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Sat, 11 Jul 2020 00:01:09 +0800 Subject: [PATCH] Fix index overflow bug of the CUDA kernel loop increment (#25435) * fix softmax_with_cross_entropy cuda kernel overflow bug, test=develop * replace old macro & for condition, test=develop * polish details, test=develop --- paddle/fluid/framework/fleet/box_wrapper.cu | 3 -- paddle/fluid/framework/lod_tensor_test.cu | 5 +- paddle/fluid/operators/batch_fc_op.cu | 4 -- paddle/fluid/operators/bce_loss_op.cu | 8 +-- paddle/fluid/operators/cvm_op.cu | 4 -- paddle/fluid/operators/data_norm_op.cu | 4 -- .../operators/deformable_psroi_pooling_op.cu | 4 -- .../detection/anchor_generator_op.cu | 8 +-- .../detection/collect_fpn_proposals_op.cu | 3 +- .../detection/distribute_fpn_proposals_op.cu | 6 +-- .../detection/generate_proposals_op.cu | 5 +- .../fluid/operators/detection/prior_box_op.cu | 8 +-- .../detection/roi_perspective_transform_op.cu | 8 +-- .../detection/sigmoid_focal_loss_op.cu | 8 +-- paddle/fluid/operators/gather.cu.h | 8 +-- paddle/fluid/operators/gather_tree_op.cu | 6 +-- paddle/fluid/operators/histogram_op.cu | 4 -- paddle/fluid/operators/instance_norm_op.cu | 3 +- paddle/fluid/operators/linspace_op.cu | 6 +-- paddle/fluid/operators/lstm_unit_op.cu | 8 +-- paddle/fluid/operators/math/cross_entropy.cu | 3 +- paddle/fluid/operators/math/math_function.cu | 3 +- paddle/fluid/operators/mean_iou_op.cu | 8 +-- paddle/fluid/operators/metrics/auc_op.cu | 3 -- paddle/fluid/operators/nll_loss_op.cu | 12 ++--- .../operators/optimizers/lars_momentum_op.cu | 3 +- paddle/fluid/operators/optimizers/sgd_op.cu | 3 +- paddle/fluid/operators/pad2d_op.cu | 28 +++++------ paddle/fluid/operators/prelu_op.cu | 5 -- paddle/fluid/operators/range_op.cu | 6 +-- paddle/fluid/operators/rank_attention.cu.h | 4 -- paddle/fluid/operators/roi_align_op.cu | 8 +-- paddle/fluid/operators/scatter.cu.h | 9 ++-- .../sigmoid_cross_entropy_with_logits_op.cu | 10 ++-- .../softmax_with_cross_entropy_op.cu | 20 ++++---- paddle/fluid/operators/transpose_op.cu | 6 +-- paddle/fluid/platform/cuda_helper.h | 49 +++++++++++++++++++ paddle/fluid/platform/cuda_helper_test.cu | 10 ++-- 38 files changed, 115 insertions(+), 188 deletions(-) diff --git a/paddle/fluid/framework/fleet/box_wrapper.cu b/paddle/fluid/framework/fleet/box_wrapper.cu index c315abd737c..31809532a69 100644 --- a/paddle/fluid/framework/fleet/box_wrapper.cu +++ b/paddle/fluid/framework/fleet/box_wrapper.cu @@ -23,9 +23,6 @@ namespace paddle { namespace framework { -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) template __global__ void PullCopy( diff --git a/paddle/fluid/framework/lod_tensor_test.cu b/paddle/fluid/framework/lod_tensor_test.cu index 7d6ba984f6f..7f0f46b1bb3 100644 --- a/paddle/fluid/framework/lod_tensor_test.cu +++ b/paddle/fluid/framework/lod_tensor_test.cu @@ -22,10 +22,7 @@ #include "paddle/fluid/platform/place.h" __global__ void test(size_t* a, int size) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; - i += blockDim.x * gridDim.x) { - a[i] *= 2; - } + CUDA_KERNEL_LOOP(i, size) { a[i] *= 2; } } TEST(LoD, data) { diff --git a/paddle/fluid/operators/batch_fc_op.cu b/paddle/fluid/operators/batch_fc_op.cu index 414eeef2a6f..9a39306ccad 100644 --- a/paddle/fluid/operators/batch_fc_op.cu +++ b/paddle/fluid/operators/batch_fc_op.cu @@ -24,10 +24,6 @@ namespace paddle { namespace operators { using framework::Tensor; 
-#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - const int CUDA_NUM_THREADS = 1024; static inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; diff --git a/paddle/fluid/operators/bce_loss_op.cu b/paddle/fluid/operators/bce_loss_op.cu index 179e194a9c5..8e30f4eb15b 100644 --- a/paddle/fluid/operators/bce_loss_op.cu +++ b/paddle/fluid/operators/bce_loss_op.cu @@ -24,14 +24,10 @@ namespace operators { using Tensor = framework::Tensor; -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void GPUBCELossForward(const T* x_data, const T* label_data, T* out_data, const int in_numel) { - CUDA_1D_KERNEL_LOOP(i, in_numel) { + CUDA_KERNEL_LOOP(i, in_numel) { T x = x_data[i]; T label = label_data[i]; T one = static_cast(1.); @@ -48,7 +44,7 @@ template __global__ void GPUBCELossBackward(const T* x_data, const T* label_data, const T* dout_data, T* dx_data, const int in_numel) { - CUDA_1D_KERNEL_LOOP(i, in_numel) { + CUDA_KERNEL_LOOP(i, in_numel) { T x = x_data[i]; T label = label_data[i]; T dout = dout_data[i]; diff --git a/paddle/fluid/operators/cvm_op.cu b/paddle/fluid/operators/cvm_op.cu index 1f8470caff1..75976c968c9 100644 --- a/paddle/fluid/operators/cvm_op.cu +++ b/paddle/fluid/operators/cvm_op.cu @@ -25,10 +25,6 @@ using platform::PADDLE_CUDA_NUM_THREADS; using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void CvmComputeKernel(const bool use_cvm, const int64_t item_width, const T* X, T* Y, int64_t numel) { diff --git a/paddle/fluid/operators/data_norm_op.cu b/paddle/fluid/operators/data_norm_op.cu index 483bb5ec5c7..9e284b1dcda 100644 --- a/paddle/fluid/operators/data_norm_op.cu +++ b/paddle/fluid/operators/data_norm_op.cu @@ -30,10 +30,6 @@ using LoDTensor = framework::LoDTensor; using DataLayout = framework::DataLayout; using platform::PADDLE_CUDA_NUM_THREADS; -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - inline int GET_BLOCKS(const int N) { return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; } diff --git a/paddle/fluid/operators/deformable_psroi_pooling_op.cu b/paddle/fluid/operators/deformable_psroi_pooling_op.cu index e977c70bf4d..c1d4cc9d17a 100644 --- a/paddle/fluid/operators/deformable_psroi_pooling_op.cu +++ b/paddle/fluid/operators/deformable_psroi_pooling_op.cu @@ -40,10 +40,6 @@ namespace operators { using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - const int CUDA_NUM_THREADS = 1024; static inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; diff --git a/paddle/fluid/operators/detection/anchor_generator_op.cu b/paddle/fluid/operators/detection/anchor_generator_op.cu index 3cc9bbeee1e..b4c27a63dbd 100644 --- a/paddle/fluid/operators/detection/anchor_generator_op.cu +++ b/paddle/fluid/operators/detection/anchor_generator_op.cu @@ -24,8 +24,7 @@ __global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num, const int width, const T offset) { int num_anchors = as_num * ar_num; int box_num = 
height * width * num_anchors; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num; - i += blockDim.x * gridDim.x) { + CUDA_KERNEL_LOOP(i, box_num) { int h_idx = i / (num_anchors * width); int w_idx = (i / num_anchors) % width; T stride_width = stride[0]; @@ -64,10 +63,7 @@ __global__ void GenAnchors(T* out, const T* aspect_ratios, const int ar_num, template __global__ void SetVariance(T* out, const T* var, const int vnum, const int num) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { - out[i] = var[i % vnum]; - } + CUDA_KERNEL_LOOP(i, num) { out[i] = var[i % vnum]; } } template diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu index 6fac90bf2da..35222a85cd3 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu @@ -40,8 +40,7 @@ static inline int NumBlocks(const int N) { static __global__ void GetLengthLoD(const int nthreads, const int* batch_ids, int* length_lod) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (nthreads); - i += blockDim.x * gridDim.x) { + CUDA_KERNEL_LOOP(i, nthreads) { platform::CudaAtomicAdd(length_lod + batch_ids[i], 1); } } diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu index 1a89af9697d..1e3cd9f36c5 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu @@ -31,10 +31,6 @@ using LoDTensor = framework::LoDTensor; static constexpr int kNumCUDAThreads = 64; static constexpr int kNumMaxinumNumBlocks = 4096; -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - int const BBoxSize = 4; static inline int NumBlocks(const int N) { @@ -48,7 +44,7 @@ __global__ void GPUDistFpnProposalsHelper( const int refer_level, const int refer_scale, const int max_level, const int min_level, int* roi_batch_id_data, int* sub_lod_list, int* target_lvls) { - CUDA_1D_KERNEL_LOOP(i, nthreads) { + CUDA_KERNEL_LOOP(i, nthreads) { const T* offset_roi = rois + i * BBoxSize; int roi_batch_ind = roi_batch_id_data[i]; // get the target level of current rois diff --git a/paddle/fluid/operators/detection/generate_proposals_op.cu b/paddle/fluid/operators/detection/generate_proposals_op.cu index aaa8dbfe602..fa7670f6d68 100644 --- a/paddle/fluid/operators/detection/generate_proposals_op.cu +++ b/paddle/fluid/operators/detection/generate_proposals_op.cu @@ -33,9 +33,6 @@ using LoDTensor = framework::LoDTensor; namespace { #define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0)) -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) int const kThreadsPerBlock = sizeof(uint64_t) * 8; @@ -155,7 +152,7 @@ static __global__ void FilterBBoxes(const T *bboxes, const T *im_info, int cnt = 0; __shared__ int keep_index[BlockSize]; - CUDA_1D_KERNEL_LOOP(i, num) { + CUDA_KERNEL_LOOP(i, num) { keep_index[threadIdx.x] = -1; __syncthreads(); diff --git a/paddle/fluid/operators/detection/prior_box_op.cu b/paddle/fluid/operators/detection/prior_box_op.cu index 1ea8cfc1d2a..1ef37e87198 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cu +++ b/paddle/fluid/operators/detection/prior_box_op.cu @@ -32,8 +32,7 @@ __global__ void 
GenPriorBox(T* out, const T* aspect_ratios, const int height, bool min_max_aspect_ratios_order) { int num_priors = max_sizes ? as_num * min_num + min_num : as_num * min_num; int box_num = height * width * num_priors; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < box_num; - i += blockDim.x * gridDim.x) { + CUDA_KERNEL_LOOP(i, box_num) { int h = i / (num_priors * width); int w = (i / num_priors) % width; int p = i % num_priors; @@ -87,10 +86,7 @@ __global__ void GenPriorBox(T* out, const T* aspect_ratios, const int height, template __global__ void SetVariance(T* out, const T* var, const int vnum, const int num) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { - out[i] = var[i % vnum]; - } + CUDA_KERNEL_LOOP(i, num) { out[i] = var[i % vnum]; } } template diff --git a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu index fe65162353e..7b34e197ffe 100644 --- a/paddle/fluid/operators/detection/roi_perspective_transform_op.cu +++ b/paddle/fluid/operators/detection/roi_perspective_transform_op.cu @@ -30,10 +30,6 @@ namespace operators { #define idx4_2(index, d1, d2, d3, d4) ((index / d4 / d3) % d2) #define idx4_1(index, d1, d2, d3, d4) ((index / d4 / d3 / d2) % d1) -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __device__ bool GT_E(T a, T b) { return (a > b) || Eigen::numext::abs(a - b) < 1e-4; @@ -284,7 +280,7 @@ __global__ void RoiTransformKernel(const float* input_data, int* mask, T* transform_matrix) { int output_size = num_rois * transformed_height * transformed_width * channels; - CUDA_1D_KERNEL_LOOP(index, output_size) { + CUDA_KERNEL_LOOP(index, output_size) { // (n, c, out_h, out_w) is an element in the transformed output int out_w = idx4_4(index, num_rois, channels, transformed_height, transformed_width); @@ -463,7 +459,7 @@ __global__ void RoiTransformGradKernel(int out_size, const int* out2in_idx_data, const T* out2in_w_data, const T* out_grad_data, T* in_grad_data) { - CUDA_1D_KERNEL_LOOP(index, out_size * 4) { + CUDA_KERNEL_LOOP(index, out_size * 4) { int in_idx = out2in_idx_data[index]; if (in_idx >= 0) { int out_idx = index / 4; diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu index 4031554aa72..f12d60c8b0f 100644 --- a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu @@ -30,10 +30,6 @@ static inline int NumBlocks(const int N) { kNumMaxinumNumBlocks); } -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void GPUSigmoidFocalLossForward(const T *x_data, const int *label_data, @@ -41,7 +37,7 @@ __global__ void GPUSigmoidFocalLossForward(const T *x_data, const T gamma, const T alpha, const int num_classes, const int limit, T *out_data) { - CUDA_1D_KERNEL_LOOP(i, limit) { + CUDA_KERNEL_LOOP(i, limit) { T x = x_data[i]; int a = i / num_classes; // current sample int d = i % num_classes; // current class @@ -79,7 +75,7 @@ __global__ void GPUSigmoidFocalLossBackward( const T *x_data, const int *label_data, const int *fg_num_data, const T gamma, const T alpha, const int num_classes, const T *dout_data, const int limit, T *dx_data) { - CUDA_1D_KERNEL_LOOP(i, limit) { + CUDA_KERNEL_LOOP(i, limit) { T x 
= x_data[i]; T dout = dout_data[i]; diff --git a/paddle/fluid/operators/gather.cu.h b/paddle/fluid/operators/gather.cu.h index 979deb8919e..f59d46ec79b 100644 --- a/paddle/fluid/operators/gather.cu.h +++ b/paddle/fluid/operators/gather.cu.h @@ -27,15 +27,11 @@ namespace operators { using framework::Tensor; using platform::DeviceContext; -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void GatherCUDAKernel(const T* params, const IndexT* indices, T* output, size_t index_size, size_t slice_size) { - CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { + CUDA_KERNEL_LOOP(i, index_size * slice_size) { int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT gather_i = indices[indices_i]; @@ -49,7 +45,7 @@ __global__ void GatherNdCUDAKernel(const T* input, const int* input_dims, const IndexT* indices, T* output, size_t remain_size, size_t slice_size, size_t end_size) { - CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) { + CUDA_KERNEL_LOOP(i, remain_size * slice_size) { int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT gather_i = 0; diff --git a/paddle/fluid/operators/gather_tree_op.cu b/paddle/fluid/operators/gather_tree_op.cu index 7ea3641b99f..c53f1e81cef 100644 --- a/paddle/fluid/operators/gather_tree_op.cu +++ b/paddle/fluid/operators/gather_tree_op.cu @@ -19,15 +19,11 @@ limitations under the License. */ namespace paddle { namespace operators { -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void GatherTree(const T *ids_data, const T *parents_data, T *out_data, const int64_t max_length, const int64_t batch_size, const int64_t beam_size) { - CUDA_1D_KERNEL_LOOP(i, batch_size * beam_size) { + CUDA_KERNEL_LOOP(i, batch_size * beam_size) { int batch = i / beam_size; int beam = i % beam_size; auto idx = diff --git a/paddle/fluid/operators/histogram_op.cu b/paddle/fluid/operators/histogram_op.cu index 359e90bfc3a..3de24ead0de 100644 --- a/paddle/fluid/operators/histogram_op.cu +++ b/paddle/fluid/operators/histogram_op.cu @@ -27,10 +27,6 @@ using IndexType = int64_t; using Tensor = framework::Tensor; using platform::PADDLE_CUDA_NUM_THREADS; -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - inline int GET_BLOCKS(const int N) { return (N + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS; } diff --git a/paddle/fluid/operators/instance_norm_op.cu b/paddle/fluid/operators/instance_norm_op.cu index 83236712098..51313835eba 100644 --- a/paddle/fluid/operators/instance_norm_op.cu +++ b/paddle/fluid/operators/instance_norm_op.cu @@ -35,8 +35,7 @@ using BatchNormParamType = typename CudnnDataType::BatchNormParamType; template static __global__ void repeat_param(const T *input, T *output, const int repeat_num, const int C) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < repeat_num * C; - i += blockDim.x * gridDim.x) { + CUDA_KERNEL_LOOP(i, repeat_num * C) { int index = i % C; output[i] = input[index]; } diff --git a/paddle/fluid/operators/linspace_op.cu b/paddle/fluid/operators/linspace_op.cu index 90bd17cda0e..47d4536dcfe 100644 --- a/paddle/fluid/operators/linspace_op.cu +++ b/paddle/fluid/operators/linspace_op.cu @@ -19,13 +19,9 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { - CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; } + CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } } template diff --git a/paddle/fluid/operators/lstm_unit_op.cu b/paddle/fluid/operators/lstm_unit_op.cu index 7d2279f16d3..810b83cb535 100644 --- a/paddle/fluid/operators/lstm_unit_op.cu +++ b/paddle/fluid/operators/lstm_unit_op.cu @@ -24,10 +24,6 @@ https://github.com/caffe2/caffe2/blob/master/caffe2/operators/lstm_unit_op_gpu.c namespace paddle { namespace operators { -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __device__ Dtype cuda_sigmoid(const Dtype x) { return Dtype(1) / (Dtype(1) + exp(-x)); @@ -42,7 +38,7 @@ template __global__ void LSTMUnitKernel(const int nthreads, const int dim, const T* C_prev, const T* X, T* C, T* H, const T forget_bias) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { const int n = index / dim; const int d = index % dim; @@ -65,7 +61,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim, const T* C_diff, const T* H_diff, T* C_prev_diff, T* X_diff, const T forget_bias) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { const int n = index / dim; const int d = index % dim; const T* X_offset = X + 4 * dim * n; diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu index 2d871c6e14b..c7fac60dd3e 100644 --- a/paddle/fluid/operators/math/cross_entropy.cu +++ b/paddle/fluid/operators/math/cross_entropy.cu @@ -25,8 +25,7 @@ template __global__ void CrossEntropyKernel(T* Y, const T* X, const int64_t* label, const int N, const int D, const int ignore_index) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { + CUDA_KERNEL_LOOP(i, N) { PADDLE_ENFORCE(label[i] >= 0 && label[i] < D || label[i] == ignore_index, "label[%d] expected >= 0 and < %ld, or == %ld, but got " "%ld. 
Please check input value.", diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 235bbb57ed6..fba143d017d 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -75,8 +75,7 @@ template __global__ void RowwiseAddKernel(const T* a, const T* b, T* c, int width, int num) { T tmp = 1.0 / width; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { + CUDA_KERNEL_LOOP(i, num) { int h = i * tmp; int w = i - h * width; c[i] = a[i] + b[w]; diff --git a/paddle/fluid/operators/mean_iou_op.cu b/paddle/fluid/operators/mean_iou_op.cu index ada1892f43d..7098a720cc3 100644 --- a/paddle/fluid/operators/mean_iou_op.cu +++ b/paddle/fluid/operators/mean_iou_op.cu @@ -23,10 +23,6 @@ namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void CountCUDAKernel(const int num_classes, const int count, const T* predictions, const T* labels, @@ -42,7 +38,7 @@ __global__ void CountCUDAKernel(const int num_classes, const int count, T pred; T label; - CUDA_1D_KERNEL_LOOP(i, count) { + CUDA_KERNEL_LOOP(i, count) { pred = predictions[i]; label = labels[i]; if (pred == label) { @@ -68,7 +64,7 @@ __global__ void ComputeIoUCUDAKernel(const int num_classes, int* wrong, valid_count_c = 0; } __syncthreads(); - CUDA_1D_KERNEL_LOOP(i, num_classes) { + CUDA_KERNEL_LOOP(i, num_classes) { int wrong_n = wrong[i]; int correct_n = correct[i]; int denominator = wrong_n + correct_n; diff --git a/paddle/fluid/operators/metrics/auc_op.cu b/paddle/fluid/operators/metrics/auc_op.cu index 04af6c51c73..13da4ff0857 100644 --- a/paddle/fluid/operators/metrics/auc_op.cu +++ b/paddle/fluid/operators/metrics/auc_op.cu @@ -23,9 +23,6 @@ namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; using Tensor = framework::Tensor; using LoDTensor = framework::LoDTensor; -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) __global__ void ClearObsoleteDataKernel(int64_t *pos, int64_t *neg, const int bucket_length, diff --git a/paddle/fluid/operators/nll_loss_op.cu b/paddle/fluid/operators/nll_loss_op.cu index 7b37239a339..3d618805f02 100644 --- a/paddle/fluid/operators/nll_loss_op.cu +++ b/paddle/fluid/operators/nll_loss_op.cu @@ -31,10 +31,6 @@ static inline int NumBlocks(const int N) { kNumMaxinumNumBlocks); } -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data, const int64_t* label_data, @@ -42,7 +38,7 @@ __global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data, const int64_t batch_size, const int64_t n_classes, const int64_t ignore_index) { - CUDA_1D_KERNEL_LOOP(i, batch_size) { + CUDA_KERNEL_LOOP(i, batch_size) { const int64_t cur_label = label_data[i]; if (cur_label == ignore_index) { out_data[i] = 0; @@ -191,7 +187,7 @@ __global__ void GPUNLLLossForward2D_no_reduce( const int64_t map_size = in_dim2 * in_dim3; const int64_t sample_size = n_classes * map_size; const int64_t out_numel = batch_size * map_size; - CUDA_1D_KERNEL_LOOP(i, out_numel) { + CUDA_KERNEL_LOOP(i, out_numel) { const int64_t b = i % batch_size; const int64_t h = (i / batch_size) % in_dim2; const int64_t w 
= (i / (batch_size * in_dim2)) % in_dim3; @@ -261,7 +257,7 @@ __global__ void GPUNLLLossBackward1D_no_reduce( T* dx_data, const int64_t* label_data, const T* weight_data, const T* dout_data, const int64_t batch_size, const int64_t n_classes, const int64_t ignore_index) { - CUDA_1D_KERNEL_LOOP(i, batch_size) { + CUDA_KERNEL_LOOP(i, batch_size) { const int64_t cur_label = label_data[i]; if (cur_label == ignore_index) { continue; @@ -299,7 +295,7 @@ __global__ void GPUNLLLossBackward2D_no_reduce( const int64_t map_size = in_dim2 * in_dim3; const int64_t sample_size = n_classes * map_size; const int64_t out_numel = batch_size * map_size; - CUDA_1D_KERNEL_LOOP(i, out_numel) { + CUDA_KERNEL_LOOP(i, out_numel) { const int64_t b = i % batch_size; const int64_t h = (i / batch_size) % in_dim2; const int64_t w = (i / (batch_size * in_dim2)) % in_dim3; diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu index a277d6ff2be..1dace4ed6ab 100644 --- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu +++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu @@ -26,8 +26,7 @@ __global__ void MomentumLarsKernel(const T* p, const T* g, const T* v, const T* g_norm, T* p_out, T* v_out) { T lr = learning_rate[0]; T local_lr = learning_rate[0]; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; - i += blockDim.x * gridDim.x) { + CUDA_KERNEL_LOOP(i, num) { if (p_norm[0] > 0 && g_norm[0] > 0) { local_lr = lr * lars_coeff * p_norm[0] / (g_norm[0] + lars_weight_decay * p_norm[0]); diff --git a/paddle/fluid/operators/optimizers/sgd_op.cu b/paddle/fluid/operators/optimizers/sgd_op.cu index 96eb51903f0..b70f24e0e5e 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cu +++ b/paddle/fluid/operators/optimizers/sgd_op.cu @@ -25,8 +25,7 @@ template __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate, const int num, T* p_out) { T lr = learning_rate[0]; - int grid_size = blockDim.x * gridDim.x; - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += grid_size) { + CUDA_KERNEL_LOOP(i, num) { T g_data = g[i]; T p_data = p[i]; p_out[i] = p_data - lr * g_data; diff --git a/paddle/fluid/operators/pad2d_op.cu b/paddle/fluid/operators/pad2d_op.cu index c05d778fb29..a77d0a5650e 100644 --- a/paddle/fluid/operators/pad2d_op.cu +++ b/paddle/fluid/operators/pad2d_op.cu @@ -23,10 +23,6 @@ namespace operators { using platform::PADDLE_CUDA_NUM_THREADS; -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - using framework::Tensor; template @@ -36,7 +32,7 @@ __global__ void Pad2DConstNCHW(const int nthreads, const T* in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, T value, T* out_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { int nc = index / out_width; const int out_w = index % out_width; const int out_h = nc % out_height; @@ -57,7 +53,7 @@ __global__ void Pad2DConstNHWC(const int nthreads, const T* in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, T value, T* out_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { int n = index / channels; const int c = index % channels; const int out_w = n % out_width; @@ -81,7 +77,7 @@ __global__ void Pad2DReflectNCHW(const int nthreads, const T* in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, T* out_data) 
{ - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { int nc = index / out_width; const int out_w = index % out_width; const int out_h = nc % out_height; @@ -103,7 +99,7 @@ __global__ void Pad2DReflectNHWC(const int nthreads, const T* in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, T* out_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { int n = index / channels; const int c = index % channels; const int out_w = n % out_width; @@ -128,7 +124,7 @@ __global__ void Pad2DEdgeNCHW(const int nthreads, const T* in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, T* out_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { int nc = index / out_width; const int out_w = index % out_width; const int out_h = nc % out_height; @@ -146,7 +142,7 @@ __global__ void Pad2DEdgeNHWC(const int nthreads, const T* in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, T* out_data) { - CUDA_1D_KERNEL_LOOP(index, nthreads) { + CUDA_KERNEL_LOOP(index, nthreads) { int n = index / channels; const int c = index % channels; const int out_w = n % out_width; @@ -167,7 +163,7 @@ __global__ void Pad2DGradConstNCHW(const int in_size, T* d_in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, const T* d_out_data) { - CUDA_1D_KERNEL_LOOP(in_index, in_size) { + CUDA_KERNEL_LOOP(in_index, in_size) { int nc = in_index / in_width; const int out_w = in_index % in_width + pad_left; const int out_h = nc % in_height + pad_top; @@ -184,7 +180,7 @@ __global__ void Pad2DGradConstNHWC(const int in_size, T* d_in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, const T* d_out_data) { - CUDA_1D_KERNEL_LOOP(in_index, in_size) { + CUDA_KERNEL_LOOP(in_index, in_size) { int n = in_index / channels; const int c = in_index % channels; const int out_w = n % in_width + pad_left; @@ -204,7 +200,7 @@ __global__ void Pad2DGradReflectNCHW(const int out_size, T* d_in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, const T* d_out_data) { - CUDA_1D_KERNEL_LOOP(out_index, out_size) { + CUDA_KERNEL_LOOP(out_index, out_size) { int nc = out_index / out_width; const int out_w = out_index % out_width; const int out_h = nc % out_height; @@ -228,7 +224,7 @@ __global__ void Pad2DGradReflectNHWC(const int out_size, T* d_in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, const T* d_out_data) { - CUDA_1D_KERNEL_LOOP(out_index, out_size) { + CUDA_KERNEL_LOOP(out_index, out_size) { const int c = out_index % channels; int n = out_index / channels; const int out_w = n % out_width; @@ -254,7 +250,7 @@ __global__ void Pad2DGradEdgeNCHW(const int out_size, T* d_in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, const T* d_out_data) { - CUDA_1D_KERNEL_LOOP(out_index, out_size) { + CUDA_KERNEL_LOOP(out_index, out_size) { int nc = out_index / out_width; const int out_w = out_index % out_width; const int out_h = nc % out_height; @@ -274,7 +270,7 @@ __global__ void Pad2DGradEdgeNHWC(const int out_size, T* d_in_data, const int out_height, const int out_width, const int pad_top, const int pad_left, const T* d_out_data) { - CUDA_1D_KERNEL_LOOP(out_index, out_size) { + CUDA_KERNEL_LOOP(out_index, out_size) { const int c = out_index % channels; int n = out_index / channels; const 
int out_w = n % out_width; diff --git a/paddle/fluid/operators/prelu_op.cu b/paddle/fluid/operators/prelu_op.cu index 2e51b00b980..2f61c53f877 100644 --- a/paddle/fluid/operators/prelu_op.cu +++ b/paddle/fluid/operators/prelu_op.cu @@ -25,11 +25,6 @@ using Tensor = framework::Tensor; #define CUDA_NUM_THREADS 1024 -// CUDA: grid stride looping -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - inline static int PADDLE_GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; } diff --git a/paddle/fluid/operators/range_op.cu b/paddle/fluid/operators/range_op.cu index e2c03716d55..c527bc74eee 100644 --- a/paddle/fluid/operators/range_op.cu +++ b/paddle/fluid/operators/range_op.cu @@ -19,13 +19,9 @@ limitations under the License. */ namespace paddle { namespace operators { -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void RangeKernel(T start, T step, int64_t size, T* out) { - CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; } + CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } } template diff --git a/paddle/fluid/operators/rank_attention.cu.h b/paddle/fluid/operators/rank_attention.cu.h index 9de3de241dc..27fe67e73cd 100644 --- a/paddle/fluid/operators/rank_attention.cu.h +++ b/paddle/fluid/operators/rank_attention.cu.h @@ -19,10 +19,6 @@ limitations under the License. */ namespace paddle { namespace operators { -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - const int CUDA_NUM_THREADS = 1024; static inline int GET_BLOCKS(const int N) { return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; diff --git a/paddle/fluid/operators/roi_align_op.cu b/paddle/fluid/operators/roi_align_op.cu index 4c868d22c78..f7ec13e5bcc 100644 --- a/paddle/fluid/operators/roi_align_op.cu +++ b/paddle/fluid/operators/roi_align_op.cu @@ -31,10 +31,6 @@ static inline int NumBlocks(const int N) { kNumMaxinumNumBlocks); } -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __device__ T BilinearInterpolate(const T* input_data, const int height, const int width, T y, T x) { @@ -110,7 +106,7 @@ __global__ void GPUROIAlignForward( const float spatial_scale, const int channels, const int height, const int width, const int pooled_height, const int pooled_width, const int sampling_ratio, int* roi_batch_id_data, T* output_data) { - CUDA_1D_KERNEL_LOOP(i, nthreads) { + CUDA_KERNEL_LOOP(i, nthreads) { int pw = i % pooled_width; int ph = (i / pooled_width) % pooled_height; int c = (i / pooled_width / pooled_height) % channels; @@ -165,7 +161,7 @@ __global__ void GPUROIAlignBackward(const int nthreads, const T* input_rois, const int pooled_width, const int sampling_ratio, int* roi_batch_id_data, T* input_grad) { - CUDA_1D_KERNEL_LOOP(i, nthreads) { + CUDA_KERNEL_LOOP(i, nthreads) { int pw = i % pooled_width; int ph = (i / pooled_width) % pooled_height; int c = (i / pooled_width / pooled_height) % channels; diff --git a/paddle/fluid/operators/scatter.cu.h b/paddle/fluid/operators/scatter.cu.h index 9de810154e6..7890d50e109 100644 --- a/paddle/fluid/operators/scatter.cu.h +++ b/paddle/fluid/operators/scatter.cu.h @@ -26,14 +26,11 @@ namespace operators { using Tensor = framework::Tensor; -#define 
CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) template __global__ void ScatterInitCUDAKernel(const IndexT* indices, T* output, size_t index_size, size_t slice_size, bool overwrite) { - CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { + CUDA_KERNEL_LOOP(i, index_size * slice_size) { int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT scatter_i = indices[indices_i]; @@ -46,7 +43,7 @@ template __global__ void ScatterCUDAKernel(const T* params, const IndexT* indices, T* output, size_t index_size, size_t slice_size, bool overwrite) { - CUDA_1D_KERNEL_LOOP(i, index_size * slice_size) { + CUDA_KERNEL_LOOP(i, index_size * slice_size) { int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT scatter_i = indices[indices_i]; @@ -64,7 +61,7 @@ __global__ void ScatterNdCUDAKernel(const T* update, const IndexT* indices, T* output, const int* output_dims, size_t remain_size, size_t slice_size, size_t end_size) { - CUDA_1D_KERNEL_LOOP(i, remain_size * slice_size) { + CUDA_KERNEL_LOOP(i, remain_size * slice_size) { int indices_i = i / slice_size; int slice_i = i - indices_i * slice_size; // offset inside the slice IndexT gather_i = 0; diff --git a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu index 7c3a0ecba02..cdcd51904e8 100644 --- a/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu +++ b/paddle/fluid/operators/sigmoid_cross_entropy_with_logits_op.cu @@ -31,15 +31,11 @@ static inline int NumBlocks(const int N) { kNumMaxinumNumBlocks); } -#define CUDA_1D_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ - i += blockDim.x * gridDim.x) - template __global__ void GPUSigmoidForward(const T *x_data, const T *label_data, const int ignore_index, const int limit, T *out_data, T *counts) { - CUDA_1D_KERNEL_LOOP(i, limit) { + CUDA_KERNEL_LOOP(i, limit) { T x = x_data[i]; T label = label_data[i]; T eps = static_cast(1e-5); @@ -77,14 +73,14 @@ __global__ void Sum(const T *counts, int num, const T eps, T *sum) { template __global__ void Div(T *loss, const int num, const T *norm) { - CUDA_1D_KERNEL_LOOP(i, num) { loss[i] /= norm[0]; } + CUDA_KERNEL_LOOP(i, num) { loss[i] /= norm[0]; } } template __global__ void GPUSigmoidBackward(const T *x_data, const T *label_data, const int ignore_index, const T *dout_data, const int limit, T *dx_data, T *counts) { - CUDA_1D_KERNEL_LOOP(i, limit) { + CUDA_KERNEL_LOOP(i, limit) { T x = x_data[i]; T label = label_data[i]; T dout = dout_data[i]; diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index dbda4b9b7e0..344dfe23996 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -24,24 +24,22 @@ template __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, const int n, const int d, const int remain, const int ignore_index) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n * remain; - i += blockDim.x * gridDim.x) { - int idx_n = i / remain; - int idx_remain = i % remain; - int idx = idx_n * d + labels[i] * remain + idx_remain; + CUDA_KERNEL_LOOP(index, n * remain) { + int idx_n = index / remain; + int idx_remain = index % remain; + int idx = idx_n * d + labels[index] * remain + idx_remain; 
     logit_grad[idx] -=
-        ignore_index == labels[i] ? static_cast<T>(0.) : static_cast<T>(1.);
+        ignore_index == labels[index] ? static_cast<T>(0.) : static_cast<T>(1.);
   }
 }
 
 template <typename T>
 __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
                       const int d, const int remain) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-       i += blockDim.x * gridDim.x) {
-    int idx_n = i / d;
-    int idx_remain = i % remain;
-    logit_grad[i] *= loss_grad[idx_n * remain + idx_remain];
+  CUDA_KERNEL_LOOP(index, num) {
+    int idx_n = index / d;
+    int idx_remain = index % remain;
+    logit_grad[index] *= loss_grad[idx_n * remain + idx_remain];
   }
 }
diff --git a/paddle/fluid/operators/transpose_op.cu b/paddle/fluid/operators/transpose_op.cu
index f2d39a35c3d..e9e55c20fc5 100644
--- a/paddle/fluid/operators/transpose_op.cu
+++ b/paddle/fluid/operators/transpose_op.cu
@@ -29,10 +29,6 @@ using Tensor = framework::Tensor;
 using Dim3 = framework::Dim3;
 using Index3 = framework::Index3;
 
-#define CUDA_1D_KERNEL_LOOP(i, n)                              \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
-       i += blockDim.x * gridDim.x)
-
 struct EqualTo {
   constexpr bool operator()(int a, int b) const { return a == b; }
 };
@@ -464,7 +460,7 @@ __global__ void TransposeSimpleKernel(int nthreads, const T* __restrict__ input,
   output_dims[pos1] = input_dims[1];
   output_dims[pos2] = input_dims[2];
 
-  CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
+  CUDA_KERNEL_LOOP(output_index, nthreads) {
     Index3 output_tensor_index = ConvertTensorIndex(output_index, output_dims);
     Index3 input_tensor_index;
diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h
index 74cf5545239..6b3f91d5205 100644
--- a/paddle/fluid/platform/cuda_helper.h
+++ b/paddle/fluid/platform/cuda_helper.h
@@ -17,6 +17,7 @@
 #include <mutex>  // NOLINT
 
 #include "paddle/fluid/platform/dynload/cublas.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/macros.h"
 
 #if CUDA_VERSION < 9000
@@ -26,6 +27,54 @@ enum cublasMath_t { CUBLAS_DEFAULT_MATH = 0 };
 namespace paddle {
 namespace platform {
 
+/*
+ * Summary: Grid-stride looping macro for CUDA kernels
+ *
+ * [ Why is this macro needed? ]
+ *
+ * The original loop in a CUDA kernel is:
+ *
+ * `for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ *       i += blockDim.x * gridDim.x)`
+ *
+ * This loop condition is risky. The value of `blockIdx.x * blockDim.x`
+ * may already exceed one billion. The first iteration is fine, but once
+ * `i += blockDim.x * gridDim.x` is executed, `i` can exceed INT_MAX and
+ * overflow to a negative value. At that point the loop condition
+ * `i < (n)` is still satisfied, so the loop keeps running and causes an
+ * illegal access to CUDA memory.
+ *
+ * Here is a real example from ERNIE that triggers the above error.
+ * The relevant values are:
+ *    - blockIdx.x = 2172938
+ *    - blockDim.x = 512
+ *    - blockIdx.x * blockDim.x = 1112543864
+ *    - INT_MAX = 2147483647
+ *
+ * So we rewrite the loop as follows; the int64_t __index__ prevents
+ * overflow in the loop increment.
+ *
+ * Parameters:
+ *    - i: loop index
+ *    - num: total number of elements
+ *
+ * Examples:
+ *    template <typename T>
+ *    __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
+ *                          const int d, const int remain) {
+ *      CUDA_KERNEL_LOOP(index, num) {
+ *        int idx_n = index / d;
+ *        int idx_remain = index % remain;
+ *        logit_grad[index] *= loss_grad[idx_n * remain + idx_remain];
+ *      }
+ *    }
+ *
+*/
+#define CUDA_KERNEL_LOOP(i, num)                             \
+  int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x; \
+  for (int i = __index__; __index__ < (num);                 \
+       __index__ += blockDim.x * gridDim.x, i = __index__)
+
 class CublasHandleHolder {
  public:
   CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
diff --git a/paddle/fluid/platform/cuda_helper_test.cu b/paddle/fluid/platform/cuda_helper_test.cu
index 9e3025bf30b..044f4d6748e 100644
--- a/paddle/fluid/platform/cuda_helper_test.cu
+++ b/paddle/fluid/platform/cuda_helper_test.cu
@@ -25,13 +25,14 @@
 #include "paddle/fluid/platform/cuda_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 
+#include "paddle/fluid/platform/cuda_helper.h"
+
 using paddle::platform::PADDLE_CUDA_NUM_THREADS;
 using paddle::platform::float16;
 
 template <typename T>
 __global__ void AddKernel(const T* data_a, T* data_b, size_t num) {
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
-       i += blockDim.x * gridDim.x) {
+  CUDA_KERNEL_LOOP(i, num) {
     paddle::platform::CudaAtomicAdd(&data_b[i], data_a[i]);
   }
 }
@@ -191,10 +192,7 @@ __forceinline__ __device__ T BlockReduce(T val) {
 template <typename T>
 __global__ void DeviceReduceSum(T* in, T* out, size_t N) {
   T sum(0);
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
-       i += blockDim.x * gridDim.x) {
-    sum += in[i];
-  }
+  CUDA_KERNEL_LOOP(i, N) { sum += in[i]; }
   sum = BlockReduce(sum);
   __syncthreads();
   if (threadIdx.x == 0) out[blockIdx.x] = sum;
-- 
GitLab
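Usage illustration (not part of the patch): the standalone sketch below shows the int64_t-based grid-stride loop in a minimal .cu file. The macro body is copied from the cuda_helper.h change above; the file name, kernel, array size, and launch configuration are illustrative assumptions, not anything from the PaddlePaddle tree.

// overflow_safe_loop_demo.cu -- minimal sketch, assumed standalone file.
#include <cstdio>
#include <cuda_runtime.h>

// Local copy of the macro added by this patch (paddle/fluid/platform/cuda_helper.h).
// __index__ is tracked in 64 bits, so the increment cannot wrap past INT_MAX
// into a negative int the way the old `int i` loop counter could.
#define CUDA_KERNEL_LOOP(i, num)                             \
  int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x; \
  for (int i = __index__; __index__ < (num);                 \
       __index__ += blockDim.x * gridDim.x, i = __index__)

// Hypothetical kernel: doubles every element using the grid-stride loop.
__global__ void ScaleByTwo(float* data, int n) {
  CUDA_KERNEL_LOOP(i, n) { data[i] *= 2.0f; }
}

int main() {
  const int n = 1 << 20;  // small size; the overflow only matters for huge grids
  float* d_data = nullptr;
  cudaMalloc(&d_data, n * sizeof(float));
  cudaMemset(d_data, 0, n * sizeof(float));
  ScaleByTwo<<<(n + 255) / 256, 256>>>(d_data, n);
  cudaDeviceSynchronize();
  printf("kernel status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(d_data);
  return 0;
}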