Unverified commit 16fe11d7, authored by Zhong Hui, committed by GitHub

fix softmax cross entropy integer overflow (#30590)

[BUG FIX] Fix softmax cross entropy overflow problem.
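The root cause: index arithmetic throughout these kernels was done in 32-bit `int`, so once a logits tensor holds more than 2^31 - 1 elements, products such as `n * d` wrap around and the kernels read and write at negative offsets. A minimal sketch of the failure mode, using a hypothetical shape rather than anything from this PR:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical logits shape [n, d] = [3000000, 1000]: 3e9 elements total.
  int n = 3000000;
  int d = 1000;
  int bad = n * d;                             // signed 32-bit overflow (UB; wraps in practice)
  int64_t good = static_cast<int64_t>(n) * d;  // widen before multiplying, as the fix does
  std::cout << bad << " vs " << good << "\n";  // typically -1294967296 vs 3000000000
  return 0;
}
```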
Parent 44ee251f
@@ -29,16 +29,16 @@ static inline int CanonicalAxis(const int axis, const int rank) {
   return axis;
 }
-static inline int SizeToAxis(const int axis, const framework::DDim dims) {
-  int size = 1;
+static inline size_t SizeToAxis(const int axis, const framework::DDim dims) {
+  size_t size = 1;
   for (int i = 0; i < axis; i++) {
     size *= dims[i];
   }
   return size;
 }
-static inline int SizeFromAxis(const int axis, const framework::DDim dims) {
-  int size = 1;
+static inline size_t SizeFromAxis(const int axis, const framework::DDim dims) {
+  size_t size = 1;
   for (int i = axis; i < dims.size(); i++) {
     size *= dims[i];
   }
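Widening the return type of these helpers is only half the fix: a caller that stores the result back into an `int` (as the CUDA kernels below previously did with `n` and `d`) would truncate it again. That is why the rest of the patch also moves kernel parameters, local indices, and launch-dimension math to `int64_t`.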
@@ -22,27 +22,27 @@ using Tensor = framework::Tensor;
 namespace {
 template <typename T>
 __global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
-                                 const int n, const int d, const int remain,
-                                 const int ignore_index) {
-  CUDA_KERNEL_LOOP(index, n * remain) {
-    int idx_n = index / remain;
-    int idx_remain = index % remain;
-    int tmp = labels[index];
+                                 const int64_t n, const int64_t d,
+                                 const int64_t remain, const int ignore_index) {
+  CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) {
+    int64_t idx_n = index / remain;
+    int64_t idx_remain = index % remain;
+    int64_t tmp = labels[index];
     if (ignore_index != tmp) {
-      int idx = idx_n * d + tmp * remain + idx_remain;
+      int64_t idx = idx_n * d + tmp * remain + idx_remain;
       logit_grad[idx] -= static_cast<T>(1.);
     }
   }
 }
 template <typename T>
-__global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
-                      const int d, const int remain, const int64_t* labels,
-                      const int ignore_index) {
-  CUDA_KERNEL_LOOP(index, num) {
-    int idx_n = index / d;
-    int idx_remain = index % remain;
-    int idx_lbl = idx_n * remain + idx_remain;
+__global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num,
+                      const int64_t d, const int64_t remain,
+                      const int64_t* labels, const int ignore_index) {
+  CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
+    int64_t idx_n = index / d;
+    int64_t idx_remain = index % remain;
+    int64_t idx_lbl = idx_n * remain + idx_remain;
     if (labels[idx_lbl] == ignore_index) {
       logit_grad[index] = static_cast<T>(0.);
     } else {
@@ -54,13 +54,14 @@ __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
 template <typename T>
 __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                                const T* loss_grad,
-                                               const T* labels, const int n,
-                                               const int d, const int remain) {
-  int ids = blockIdx.x * blockDim.x + threadIdx.x;
+                                               const T* labels, const int64_t n,
+                                               const int64_t d,
+                                               const int64_t remain) {
+  int64_t ids = blockIdx.x * blockDim.x + threadIdx.x;
   if (ids < n * d) {
-    int idx_n = ids / d;
-    int idx_remain = ids % remain;
-    int idx_loss = idx_n * remain + idx_remain;
+    int64_t idx_n = ids / d;
+    int64_t idx_remain = ids % remain;
+    int64_t idx_loss = idx_n * remain + idx_remain;
     logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
   }
 }
@@ -132,19 +133,19 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
 // This kernel is used to calculate the max element of each row
 template <typename T, int BlockDim>
 static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
-                                          int d, int axis_dim) {
+                                          int64_t d, int axis_dim) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
   // logits_data view as [n, axis_dim, remain]
   // max_data view as [n, 1, remain]
   // blockDim = n * remain, split blockIdx to idx_n and idx_remain
-  int remain = d / axis_dim;
-  int idx_n = blockIdx.x / remain;
-  int idx_remain = blockIdx.x % remain;
-  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
-  int end_idx = (idx_n + 1) * d;
+  int64_t remain = d / axis_dim;
+  int64_t idx_n = blockIdx.x / remain;
+  int64_t idx_remain = blockIdx.x % remain;
+  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
+  int64_t end_idx = (idx_n + 1) * d;
-  int step = BlockDim * remain;
+  int64_t step = BlockDim * remain;
   T cur_max = logits_data[beg_idx];
   beg_idx += step;
   while (beg_idx < end_idx) {
@@ -162,21 +163,21 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
 // Make sure that BlockDim <= axis_dim
 template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
 static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
-                                                 T* max_data, T* softmax, int d,
-                                                 int axis_dim) {
+                                                 T* max_data, T* softmax,
+                                                 int64_t d, int axis_dim) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
   // logits, softmax data view as [n, axis_dim, remain]
   // max_data view as [n, 1, remain]
   // blockDim = n * remain, split blockIdx to idx_n and idx_remain
-  int remain = d / axis_dim;
-  int idx_n = blockIdx.x / remain;
-  int idx_remain = blockIdx.x % remain;
-  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
-  int end_idx = (idx_n + 1) * d;
+  int64_t remain = d / axis_dim;
+  int64_t idx_n = blockIdx.x / remain;
+  int64_t idx_remain = blockIdx.x % remain;
+  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
+  int64_t end_idx = (idx_n + 1) * d;
   auto block_max = max_data[blockIdx.x];
-  int step = BlockDim * remain;
+  int64_t step = BlockDim * remain;
   // In numeric stable mode softmax_with_loss, we calc loss with
   // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of
@@ -216,25 +217,25 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
 // Make sure that BlockDim <= axis_dim
 template <typename T, int BlockDim>
 static __global__ void RowReductionForSoftmaxAndCrossEntropy(
-    const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d,
-    int axis_dim) {
+    const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
+    int64_t d, int axis_dim) {
   __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
   // logits, softmax, labels data view as [n, axis_dim, remain]
   // loss_data view as [n, 1, remain]
   // blockDim = n * remain, split blockIdx to idx_n and idx_remain
-  int remain = d / axis_dim;
-  int idx_n = blockIdx.x / remain;
-  int idx_remain = blockIdx.x % remain;
-  int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
-  int end_idx = (idx_n + 1) * d;
+  int64_t remain = d / axis_dim;
+  int64_t idx_n = blockIdx.x / remain;
+  int64_t idx_remain = blockIdx.x % remain;
+  int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
+  int64_t end_idx = (idx_n + 1) * d;
   // log_diff_max_sum shares memory with loss
   auto block_log_diff_max_sum = loss_data[blockIdx.x];
   auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
   softmax[beg_idx] = exp_on_device(tmp);
   auto loss = -labels_data[beg_idx] * tmp;
-  int step = BlockDim * remain;
+  int64_t step = BlockDim * remain;
   beg_idx += step;
   while (beg_idx < end_idx) {
     tmp = softmax[beg_idx] - block_log_diff_max_sum;
@@ -251,21 +252,22 @@ template <typename T>
 struct HardLabelSoftmaxWithCrossEntropyFunctor {
  public:
   HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
-                                          T* log_softmax, int d, int axis_dim)
+                                          T* log_softmax, int64_t d,
+                                          int axis_dim)
       : labels_(labels),
        loss_(loss),
        log_softmax_(log_softmax),
        d_(d),
        axis_dim_(axis_dim) {}
-  __device__ void operator()(int idx) const {
+  __device__ void operator()(int64_t idx) const {
     // logits view as [n, axis_dim, remain], where d = axis_dim * remain
-    int remain = d_ / axis_dim_;
-    int idx_n = idx / d_;
-    int idx_axis = (idx % d_) / remain;
-    int idx_remain = idx % remain;
+    int64_t remain = d_ / axis_dim_;
+    int64_t idx_n = idx / d_;
+    int64_t idx_axis = (idx % d_) / remain;
+    int64_t idx_remain = idx % remain;
     // labels, loss view as [n, remain]
-    int idx_lbl = idx_n * remain + idx_remain;
+    int64_t idx_lbl = idx_n * remain + idx_remain;
     // It also would ignore labels not in range(class_num).
     if (idx_axis != labels_[idx_lbl]) {
       log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
@@ -280,7 +282,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
   const int64_t* labels_;
   T* loss_;
   T* log_softmax_;
-  int d_;
+  int64_t d_;
   int axis_dim_;
 };
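A worked example of the `[n, axis_dim, remain]` view this functor decomposes (sizes hypothetical): with `axis_dim = 6` and `remain = 4`, `d = 24`, a flat index `idx = 50` splits into `idx_n = 50 / 24 = 2`, `idx_axis = (50 % 24) / 4 = 0`, and `idx_remain = 50 % 4 = 2`, giving `idx_lbl = 2 * 4 + 2 = 10`. With every intermediate now `int64_t`, products such as `idx_n * remain` stay exact even when `n * d` exceeds `INT_MAX`.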
@@ -289,7 +291,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
  public:
   HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
                                                        T* loss, T* log_softmax,
-                                                       int d, int axis_dim,
+                                                       int64_t d, int axis_dim,
                                                        int ignore_idx)
       : labels_(labels),
        loss_(loss),
@@ -298,14 +300,14 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
        axis_dim_(axis_dim),
        ignore_idx_(ignore_idx) {}
-  __device__ void operator()(int idx) const {
+  __device__ void operator()(int64_t idx) const {
     // logits view as [n, axis_dim, remain], where d = axis_dim * remain
-    int remain = d_ / axis_dim_;
-    int idx_n = idx / d_;
-    int idx_axis = (idx % d_) / remain;
-    int idx_remain = idx % remain;
+    int64_t remain = d_ / axis_dim_;
+    int64_t idx_n = idx / d_;
+    int64_t idx_axis = (idx % d_) / remain;
+    int64_t idx_remain = idx % remain;
     // labels, loss view as [n, remain]
-    int idx_lbl = idx_n * remain + idx_remain;
+    int64_t idx_lbl = idx_n * remain + idx_remain;
     if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
       log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
     } else {
@@ -319,7 +321,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
   const int64_t* labels_;
   T* loss_;
   T* log_softmax_;
-  int d_;
+  int64_t d_;
   int axis_dim_;
   int ignore_idx_;
 };
@@ -327,13 +329,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
 template <typename T>
 static void HardLabelSoftmaxWithCrossEntropy(
     const platform::CUDADeviceContext& ctx, const T* logits_data,
-    const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d,
-    int axis_dim, int ignore_idx) {
+    const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
+    int64_t d, int axis_dim, int ignore_idx) {
   constexpr int kMaxBlockDim = 512;
-  int block_dim = axis_dim >= kMaxBlockDim
-                      ? kMaxBlockDim
-                      : (1 << static_cast<int>(std::log2(axis_dim)));
-  int grid_dim = n * d / axis_dim;
+  int64_t block_dim = axis_dim >= kMaxBlockDim
+                          ? kMaxBlockDim
+                          : (1 << static_cast<int>(std::log2(axis_dim)));
+  int64_t grid_dim = n * d / axis_dim;
   auto stream = ctx.stream();
 #define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
@@ -372,16 +374,14 @@ static void HardLabelSoftmaxWithCrossEntropy(
 }
 template <typename T>
-static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data,
-                                               const T* labels_data,
-                                               T* softmax_data, T* loss_data,
-                                               int n, int d, int axis_dim,
-                                               cudaStream_t stream) {
+static void SoftmaxWithCrossEntropyFusedKernel(
+    const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
+    int64_t n, int64_t d, int axis_dim, cudaStream_t stream) {
   constexpr int kMaxBlockDim = 512;
-  int block_dim = axis_dim >= kMaxBlockDim
-                      ? kMaxBlockDim
-                      : (1 << static_cast<int>(std::log2(axis_dim)));
-  int grid_dim = n * d / axis_dim;
+  int64_t block_dim = axis_dim >= kMaxBlockDim
+                          ? kMaxBlockDim
+                          : (1 << static_cast<int>(std::log2(axis_dim)));
+  int64_t grid_dim = n * d / axis_dim;
 #define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
   case BlockDim:                                               \
@@ -430,8 +430,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logits->dims()[axis];
-    const int n = SizeToAxis(axis, logits->dims());
-    const int d = SizeFromAxis(axis, logits->dims());
+    const int64_t n = SizeToAxis(axis, logits->dims());
+    const int64_t d = SizeFromAxis(axis, logits->dims());
     auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
     auto* loss_data = loss->mutable_data<T>(context.GetPlace());
@@ -500,24 +500,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
     int axis_dim = logit_grad->dims()[axis];
-    const int n = SizeToAxis(axis, logit_grad->dims());
-    const int d = SizeFromAxis(axis, logit_grad->dims());
-    const int remain = d / axis_dim;
+    const int64_t n = SizeToAxis(axis, logit_grad->dims());
+    const int64_t d = SizeFromAxis(axis, logit_grad->dims());
+    const int64_t remain = d / axis_dim;
     int block = 512;
     auto stream = context.cuda_device_context().stream();
     auto ignore_index = context.Attr<int>("ignore_index");
     if (context.Attr<bool>("soft_label")) {
-      int grid = (n * d + block - 1) / block;
+      int64_t grid = (n * d + block - 1) / block;
       const T* label_data = labels->data<T>();
       SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
          logit_grad_data, loss_grad_data, label_data, n, d, remain);
     } else {
-      int grid = (n * remain + block - 1) / block;
+      int64_t grid = (n * remain + block - 1) / block;
       const int64_t* label_data = labels->data<int64_t>();
       CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
          logit_grad_data, label_data, n, d, remain, ignore_index);
-      int num = n * d;
+      int64_t num = n * d;
       grid = (num + block - 1) / block;
       Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
                                            d, remain, label_data, ignore_index);
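To see the scale involved (sizes hypothetical): with `n = 4` and `d = 600000000`, `num = n * d = 2400000000` wraps to a negative value as a 32-bit `int` but is exact as `int64_t`, and `grid = (num + 511) / 512 = 4687500`, comfortably below CUDA's `gridDim.x` limit of 2^31 - 1 on compute capability 3.0 and newer.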
@@ -75,11 +75,14 @@ namespace platform {
  * }
  *
  */
-#define CUDA_KERNEL_LOOP(i, num)                               \
+#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)              \
   int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x;   \
-  for (int i = __index__; __index__ < (num);                   \
+  for (index_type i = __index__; __index__ < (num);            \
        __index__ += blockDim.x * gridDim.x, i = __index__)
+#define CUDA_KERNEL_LOOP(i, num) CUDA_KERNEL_LOOP_TYPE(i, num, int)
 class CublasHandleHolder {
  public:
 #ifdef PADDLE_WITH_HIP
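For illustration, a kernel can now opt into a wider loop index via the new macro; `FillOne` below is a hypothetical kernel, not part of the patch, and assumes this header is included. `CUDA_KERNEL_LOOP` keeps `int i`, so existing callers compile unchanged:

```cpp
template <typename T>
__global__ void FillOne(T* out, int64_t numel) {
  // Expands to: int64_t __index__ = blockIdx.x * blockDim.x + threadIdx.x;
  //             for (int64_t i = __index__; __index__ < (numel); ...)
  CUDA_KERNEL_LOOP_TYPE(i, numel, int64_t) {
    out[i] = static_cast<T>(1);
  }
}
```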
@@ -48,7 +48,7 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
 }
 template <typename Function>
-__global__ static void ForRangeElemwiseOp(Function func, int limit) {
+__global__ static void ForRangeElemwiseOp(Function func, size_t limit) {
   size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
   if (idx < limit) {
     func(idx);
@@ -58,13 +58,13 @@ __global__ static void ForRangeElemwiseOp(Function func, int limit) {
 template <>
 struct ForRange<CUDADeviceContext> {
   ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
-      : dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}
+      : dev_ctx_(dev_ctx), limit_(static_cast<size_t>(limit)) {}
   template <typename Function>
   inline void operator()(Function func) const {
     constexpr int num_threads = 1024;
-    int block_size = limit_ <= num_threads ? limit_ : num_threads;
-    int grid_size = (limit_ + num_threads - 1) / num_threads;
+    size_t block_size = limit_ <= num_threads ? limit_ : num_threads;
+    size_t grid_size = (limit_ + num_threads - 1) / num_threads;
     if (grid_size == 1) {
       ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
@@ -76,7 +76,7 @@ struct ForRange<CUDADeviceContext> {
   }
   const CUDADeviceContext& dev_ctx_;
-  int limit_;
+  size_t limit_;
 };
 #endif
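The constructor change is the key one here: `static_cast<int>(limit)` silently truncated any `limit` above `INT_MAX` (a hypothetical `limit = 3000000000` typically becomes `-1294967296`), so the block/grid computation in `operator()` ran on a negative count and produced a bogus launch configuration. Storing `limit_` as `size_t` preserves the value, and `block_size`/`grid_size` are now computed in `size_t` as well.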