From 30a2e7f0e03dd3be8d191c2ed3fd57a73776d748 Mon Sep 17 00:00:00 2001 From: Zhong Hui Date: Tue, 23 Feb 2021 19:37:54 +0800 Subject: [PATCH] [cherry-pick] Fix softmax cross entropy integer overflow. (#30590) (#31134) [BUG FIX] Fix softmax cross entropy overflow problem. --- paddle/fluid/operators/log_softmax_op.h | 8 +- .../softmax_with_cross_entropy_op.cu | 164 +++++++++--------- paddle/fluid/platform/cuda_helper.h | 7 +- paddle/fluid/platform/for_range.h | 10 +- 4 files changed, 96 insertions(+), 93 deletions(-) diff --git a/paddle/fluid/operators/log_softmax_op.h b/paddle/fluid/operators/log_softmax_op.h index b983ac54157..c732ec5a2da 100644 --- a/paddle/fluid/operators/log_softmax_op.h +++ b/paddle/fluid/operators/log_softmax_op.h @@ -29,16 +29,16 @@ static inline int CanonicalAxis(const int axis, const int rank) { return axis; } -static inline int SizeToAxis(const int axis, const framework::DDim dims) { - int size = 1; +static inline size_t SizeToAxis(const int axis, const framework::DDim dims) { + size_t size = 1; for (int i = 0; i < axis; i++) { size *= dims[i]; } return size; } -static inline int SizeFromAxis(const int axis, const framework::DDim dims) { - int size = 1; +static inline size_t SizeFromAxis(const int axis, const framework::DDim dims) { + size_t size = 1; for (int i = axis; i < dims.size(); i++) { size *= dims[i]; } diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index f86f02544dc..cb4eeab56a6 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -22,27 +22,27 @@ using Tensor = framework::Tensor; namespace { template __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels, - const int n, const int d, const int remain, - const int ignore_index) { - CUDA_KERNEL_LOOP(index, n * remain) { - int idx_n = index / remain; - int idx_remain = index % remain; - int tmp = labels[index]; + const int64_t n, const int64_t d, + const int64_t remain, const int ignore_index) { + CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) { + int64_t idx_n = index / remain; + int64_t idx_remain = index % remain; + int64_t tmp = labels[index]; if (ignore_index != tmp) { - int idx = idx_n * d + tmp * remain + idx_remain; + int64_t idx = idx_n * d + tmp * remain + idx_remain; logit_grad[idx] -= static_cast(1.); } } } template -__global__ void Scale(T* logit_grad, const T* loss_grad, const int num, - const int d, const int remain, const int64_t* labels, - const int ignore_index) { - CUDA_KERNEL_LOOP(index, num) { - int idx_n = index / d; - int idx_remain = index % remain; - int idx_lbl = idx_n * remain + idx_remain; +__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num, + const int64_t d, const int64_t remain, + const int64_t* labels, const int ignore_index) { + CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) { + int64_t idx_n = index / d; + int64_t idx_remain = index % remain; + int64_t idx_lbl = idx_n * remain + idx_remain; if (labels[idx_lbl] == ignore_index) { logit_grad[index] = static_cast(0.); } else { @@ -54,13 +54,14 @@ __global__ void Scale(T* logit_grad, const T* loss_grad, const int num, template __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, const T* loss_grad, - const T* labels, const int n, - const int d, const int remain) { - int ids = blockIdx.x * blockDim.x + threadIdx.x; + const T* labels, const int64_t n, + const int64_t d, + const int64_t remain) { + int64_t ids = blockIdx.x * blockDim.x + threadIdx.x; if (ids < n * d) { - int idx_n = ids / d; - int idx_remain = ids % remain; - int idx_loss = idx_n * remain + idx_remain; + int64_t idx_n = ids / d; + int64_t idx_remain = ids % remain; + int64_t idx_loss = idx_n * remain + idx_remain; logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]); } } @@ -132,19 +133,19 @@ using BlockReduceTempStorage = typename BlockReduce::TempStorage; // This kernel is used to calculate the max element of each row template static __global__ void RowReductionForMax(const T* logits_data, T* max_data, - int d, int axis_dim) { + int64_t d, int axis_dim) { __shared__ BlockReduceTempStorage temp_storage; // logits_data view as [n, axis_dim, remain] // max_data view as [n, 1, remain] // blockDim = n * remain, split blockIdx to idx_n and idx_remain - int remain = d / axis_dim; - int idx_n = blockIdx.x / remain; - int idx_remain = blockIdx.x % remain; - int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int end_idx = (idx_n + 1) * d; + int64_t remain = d / axis_dim; + int64_t idx_n = blockIdx.x / remain; + int64_t idx_remain = blockIdx.x % remain; + int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; + int64_t end_idx = (idx_n + 1) * d; - int step = BlockDim * remain; + int64_t step = BlockDim * remain; T cur_max = logits_data[beg_idx]; beg_idx += step; while (beg_idx < end_idx) { @@ -162,21 +163,21 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data, // Make sure that BlockDim <= axis_dim template static __global__ void RowReductionForDiffMaxSum(const T* logits_data, - T* max_data, T* softmax, int d, - int axis_dim) { + T* max_data, T* softmax, + int64_t d, int axis_dim) { __shared__ BlockReduceTempStorage temp_storage; // logits, softmax data view as [n, axis_dim, remain] // max_data view as [n, 1, remain] // blockDim = n * remain, split blockIdx to idx_n and idx_remain - int remain = d / axis_dim; - int idx_n = blockIdx.x / remain; - int idx_remain = blockIdx.x % remain; - int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int end_idx = (idx_n + 1) * d; + int64_t remain = d / axis_dim; + int64_t idx_n = blockIdx.x / remain; + int64_t idx_remain = blockIdx.x % remain; + int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; + int64_t end_idx = (idx_n + 1) * d; auto block_max = max_data[blockIdx.x]; - int step = BlockDim * remain; + int64_t step = BlockDim * remain; // In numeric stable mode softmax_with_loss, we calc loss with // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of @@ -216,25 +217,25 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data, // Make sure that BlockDim <= axis_dim template static __global__ void RowReductionForSoftmaxAndCrossEntropy( - const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d, - int axis_dim) { + const T* logits_data, const T* labels_data, T* loss_data, T* softmax, + int64_t d, int axis_dim) { __shared__ BlockReduceTempStorage temp_storage; // logits, softmax, labels data view as [n, axis_dim, remain] // loss_data view as [n, 1, remain] // blockDim = n * remain, split blockIdx to idx_n and idx_remain - int remain = d / axis_dim; - int idx_n = blockIdx.x / remain; - int idx_remain = blockIdx.x % remain; - int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; - int end_idx = (idx_n + 1) * d; + int64_t remain = d / axis_dim; + int64_t idx_n = blockIdx.x / remain; + int64_t idx_remain = blockIdx.x % remain; + int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; + int64_t end_idx = (idx_n + 1) * d; // log_diff_max_sum shares memory with loss auto block_log_diff_max_sum = loss_data[blockIdx.x]; auto tmp = softmax[beg_idx] - block_log_diff_max_sum; softmax[beg_idx] = exp_on_device(tmp); auto loss = -labels_data[beg_idx] * tmp; - int step = BlockDim * remain; + int64_t step = BlockDim * remain; beg_idx += step; while (beg_idx < end_idx) { tmp = softmax[beg_idx] - block_log_diff_max_sum; @@ -251,21 +252,22 @@ template struct HardLabelSoftmaxWithCrossEntropyFunctor { public: HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, - T* log_softmax, int d, int axis_dim) + T* log_softmax, int64_t d, + int axis_dim) : labels_(labels), loss_(loss), log_softmax_(log_softmax), d_(d), axis_dim_(axis_dim) {} - __device__ void operator()(int idx) const { + __device__ void operator()(int64_t idx) const { // logits view as [n, axis_dim, remain], where d = axis_dim * remain - int remain = d_ / axis_dim_; - int idx_n = idx / d_; - int idx_axis = (idx % d_) / remain; - int idx_remain = idx % remain; + int64_t remain = d_ / axis_dim_; + int64_t idx_n = idx / d_; + int64_t idx_axis = (idx % d_) / remain; + int64_t idx_remain = idx % remain; // labels, loss view as [n, remain] - int idx_lbl = idx_n * remain + idx_remain; + int64_t idx_lbl = idx_n * remain + idx_remain; // It also would ignore labels not in range(class_num). if (idx_axis != labels_[idx_lbl]) { log_softmax_[idx] = exp_on_device(log_softmax_[idx]); @@ -280,7 +282,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { const int64_t* labels_; T* loss_; T* log_softmax_; - int d_; + int64_t d_; int axis_dim_; }; @@ -289,7 +291,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { public: HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, T* loss, T* log_softmax, - int d, int axis_dim, + int64_t d, int axis_dim, int ignore_idx) : labels_(labels), loss_(loss), @@ -298,14 +300,14 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { axis_dim_(axis_dim), ignore_idx_(ignore_idx) {} - __device__ void operator()(int idx) const { + __device__ void operator()(int64_t idx) const { // logits view as [n, axis_dim, remain], where d = axis_dim * remain - int remain = d_ / axis_dim_; - int idx_n = idx / d_; - int idx_axis = (idx % d_) / remain; - int idx_remain = idx % remain; + int64_t remain = d_ / axis_dim_; + int64_t idx_n = idx / d_; + int64_t idx_axis = (idx % d_) / remain; + int64_t idx_remain = idx % remain; // labels, loss view as [n, remain] - int idx_lbl = idx_n * remain + idx_remain; + int64_t idx_lbl = idx_n * remain + idx_remain; if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) { log_softmax_[idx] = exp_on_device(log_softmax_[idx]); } else { @@ -319,7 +321,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { const int64_t* labels_; T* loss_; T* log_softmax_; - int d_; + int64_t d_; int axis_dim_; int ignore_idx_; }; @@ -327,13 +329,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { template static void HardLabelSoftmaxWithCrossEntropy( const platform::CUDADeviceContext& ctx, const T* logits_data, - const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d, - int axis_dim, int ignore_idx) { + const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n, + int64_t d, int axis_dim, int ignore_idx) { constexpr int kMaxBlockDim = 512; - int block_dim = axis_dim >= kMaxBlockDim - ? kMaxBlockDim - : (1 << static_cast(std::log2(axis_dim))); - int grid_dim = n * d / axis_dim; + int64_t block_dim = axis_dim >= kMaxBlockDim + ? kMaxBlockDim + : (1 << static_cast(std::log2(axis_dim))); + int64_t grid_dim = n * d / axis_dim; auto stream = ctx.stream(); #define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ @@ -372,16 +374,14 @@ static void HardLabelSoftmaxWithCrossEntropy( } template -static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data, - const T* labels_data, - T* softmax_data, T* loss_data, - int n, int d, int axis_dim, - cudaStream_t stream) { +static void SoftmaxWithCrossEntropyFusedKernel( + const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data, + int64_t n, int64_t d, int axis_dim, cudaStream_t stream) { constexpr int kMaxBlockDim = 512; - int block_dim = axis_dim >= kMaxBlockDim - ? kMaxBlockDim - : (1 << static_cast(std::log2(axis_dim))); - int grid_dim = n * d / axis_dim; + int64_t block_dim = axis_dim >= kMaxBlockDim + ? kMaxBlockDim + : (1 << static_cast(std::log2(axis_dim))); + int64_t grid_dim = n * d / axis_dim; #define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ case BlockDim: \ @@ -430,8 +430,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { const int axis = CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logits->dims()[axis]; - const int n = SizeToAxis(axis, logits->dims()); - const int d = SizeFromAxis(axis, logits->dims()); + const int64_t n = SizeToAxis(axis, logits->dims()); + const int64_t d = SizeFromAxis(axis, logits->dims()); auto* softmax_data = softmax->mutable_data(context.GetPlace()); auto* loss_data = loss->mutable_data(context.GetPlace()); @@ -500,24 +500,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { const int axis = CanonicalAxis(context.Attr("axis"), rank); int axis_dim = logit_grad->dims()[axis]; - const int n = SizeToAxis(axis, logit_grad->dims()); - const int d = SizeFromAxis(axis, logit_grad->dims()); - const int remain = d / axis_dim; + const int64_t n = SizeToAxis(axis, logit_grad->dims()); + const int64_t d = SizeFromAxis(axis, logit_grad->dims()); + const int64_t remain = d / axis_dim; int block = 512; auto stream = context.cuda_device_context().stream(); auto ignore_index = context.Attr("ignore_index"); if (context.Attr("soft_label")) { - int grid = (n * d + block - 1) / block; + int64_t grid = (n * d + block - 1) / block; const T* label_data = labels->data(); SoftCrossEntropyGradientKernel<<>>( logit_grad_data, loss_grad_data, label_data, n, d, remain); } else { - int grid = (n * remain + block - 1) / block; + int64_t grid = (n * remain + block - 1) / block; const int64_t* label_data = labels->data(); CrossEntropyGrad<<>>( logit_grad_data, label_data, n, d, remain, ignore_index); - int num = n * d; + int64_t num = n * d; grid = (num + block - 1) / block; Scale<<>>(logit_grad_data, loss_grad_data, num, d, remain, label_data, ignore_index); diff --git a/paddle/fluid/platform/cuda_helper.h b/paddle/fluid/platform/cuda_helper.h index 2a1f0b9ac5c..2a055fda4e9 100644 --- a/paddle/fluid/platform/cuda_helper.h +++ b/paddle/fluid/platform/cuda_helper.h @@ -70,11 +70,14 @@ namespace platform { * } * */ -#define CUDA_KERNEL_LOOP(i, num) \ + +#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type) \ int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x; \ - for (int i = __index__; __index__ < (num); \ + for (index_type i = __index__; __index__ < (num); \ __index__ += blockDim.x * gridDim.x, i = __index__) +#define CUDA_KERNEL_LOOP(i, num) CUDA_KERNEL_LOOP_TYPE(i, num, int) + class CublasHandleHolder { public: CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) { diff --git a/paddle/fluid/platform/for_range.h b/paddle/fluid/platform/for_range.h index c153e80fe42..d922f5a29e0 100644 --- a/paddle/fluid/platform/for_range.h +++ b/paddle/fluid/platform/for_range.h @@ -48,7 +48,7 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) { } template -__global__ static void ForRangeElemwiseOp(Function func, int limit) { +__global__ static void ForRangeElemwiseOp(Function func, size_t limit) { size_t idx = static_cast(blockIdx.x * blockDim.x + threadIdx.x); if (idx < limit) { func(idx); @@ -58,13 +58,13 @@ __global__ static void ForRangeElemwiseOp(Function func, int limit) { template <> struct ForRange { ForRange(const CUDADeviceContext& dev_ctx, size_t limit) - : dev_ctx_(dev_ctx), limit_(static_cast(limit)) {} + : dev_ctx_(dev_ctx), limit_(static_cast(limit)) {} template inline void operator()(Function func) const { constexpr int num_threads = 1024; - int block_size = limit_ <= num_threads ? limit_ : num_threads; - int grid_size = (limit_ + num_threads - 1) / num_threads; + size_t block_size = limit_ <= num_threads ? limit_ : num_threads; + size_t grid_size = (limit_ + num_threads - 1) / num_threads; if (grid_size == 1) { ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>( @@ -76,7 +76,7 @@ struct ForRange { } const CUDADeviceContext& dev_ctx_; - int limit_; + size_t limit_; }; #endif -- GitLab