未验证 提交 30a2e7f0 编写于 作者: Z Zhong Hui 提交者: GitHub

[cherry-pick] Fix softmax cross entropy integer overflow. (#30590) (#31134)

[BUG FIX] Fix softmax cross entropy overflow problem.
上级 3a72408f
...@@ -29,16 +29,16 @@ static inline int CanonicalAxis(const int axis, const int rank) { ...@@ -29,16 +29,16 @@ static inline int CanonicalAxis(const int axis, const int rank) {
return axis; return axis;
} }
// Returns the product of the extents of every dimension that precedes `axis`
// (dims[0] * ... * dims[axis - 1]); yields 1 when `axis` is 0 or negative.
// The product is accumulated in size_t so it is not limited to 32-bit int.
static inline size_t SizeToAxis(const int axis, const framework::DDim dims) {
  size_t leading = 1;
  int dim = 0;
  while (dim < axis) {
    leading *= dims[dim];
    ++dim;
  }
  return leading;
}
static inline int SizeFromAxis(const int axis, const framework::DDim dims) { static inline size_t SizeFromAxis(const int axis, const framework::DDim dims) {
int size = 1; size_t size = 1;
for (int i = axis; i < dims.size(); i++) { for (int i = axis; i < dims.size(); i++) {
size *= dims[i]; size *= dims[i];
} }
......
...@@ -22,27 +22,27 @@ using Tensor = framework::Tensor; ...@@ -22,27 +22,27 @@ using Tensor = framework::Tensor;
namespace { namespace {
// Hard-label backward step: for each of the n * remain label slots, subtracts
// 1 from the logit-gradient entry that belongs to the labelled class.
//
//   logit_grad : flat buffer addressed as idx_n * d + label * remain + idx_remain
//   labels     : flat buffer of n * remain class ids
//   n, d, remain : outer size, per-sample size and inner size; d / remain is
//                  the number of classes (presumably axis_dim -- confirm
//                  against the caller)
//   ignore_index : slots whose label equals this value are left untouched
//
// Labels outside [0, d / remain) are skipped as well. Without that check an
// invalid label would turn into an out-of-bounds write on logit_grad.
// All index arithmetic is done in int64_t so tensors with more than 2^31
// elements are addressed correctly.
template <typename T>
__global__ void CrossEntropyGrad(T* logit_grad, const int64_t* labels,
                                 const int64_t n, const int64_t d,
                                 const int64_t remain, const int ignore_index) {
  CUDA_KERNEL_LOOP_TYPE(index, n * remain, int64_t) {
    int64_t idx_n = index / remain;
    int64_t idx_remain = index % remain;
    int64_t tmp = labels[index];
    // remain > 0 is guaranteed here because the loop body only runs when
    // n * remain > 0, so the division below is safe.
    const bool in_range = tmp >= 0 && tmp < d / remain;
    if (ignore_index != tmp && in_range) {
      int64_t idx = idx_n * d + tmp * remain + idx_remain;
      logit_grad[idx] -= static_cast<T>(1.);
    }
  }
}
template <typename T> template <typename T>
__global__ void Scale(T* logit_grad, const T* loss_grad, const int num, __global__ void Scale(T* logit_grad, const T* loss_grad, const int64_t num,
const int d, const int remain, const int64_t* labels, const int64_t d, const int64_t remain,
const int ignore_index) { const int64_t* labels, const int ignore_index) {
CUDA_KERNEL_LOOP(index, num) { CUDA_KERNEL_LOOP_TYPE(index, num, int64_t) {
int idx_n = index / d; int64_t idx_n = index / d;
int idx_remain = index % remain; int64_t idx_remain = index % remain;
int idx_lbl = idx_n * remain + idx_remain; int64_t idx_lbl = idx_n * remain + idx_remain;
if (labels[idx_lbl] == ignore_index) { if (labels[idx_lbl] == ignore_index) {
logit_grad[index] = static_cast<T>(0.); logit_grad[index] = static_cast<T>(0.);
} else { } else {
...@@ -54,13 +54,14 @@ __global__ void Scale(T* logit_grad, const T* loss_grad, const int num, ...@@ -54,13 +54,14 @@ __global__ void Scale(T* logit_grad, const T* loss_grad, const int num,
// Soft-label backward kernel. One thread handles one element of logit_grad:
//
//   logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids])
//
// where idx_loss = (ids / d) * remain + (ids % remain).
//
//   logit_grad : n * d elements, updated in place (presumably holds the
//                softmax output on entry -- confirm against the caller)
//   loss_grad  : read at idx_loss, i.e. one value per (sample, inner) pair
//   labels     : n * d soft-label values, same layout as logit_grad
//
// Launch with a 1-D grid that covers at least n * d threads; surplus threads
// exit through the bounds check.
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
                                               const T* loss_grad,
                                               const T* labels, const int64_t n,
                                               const int64_t d,
                                               const int64_t remain) {
  // Widen blockIdx.x BEFORE multiplying: blockIdx.x * blockDim.x is otherwise
  // evaluated in 32-bit unsigned arithmetic and wraps once the global thread
  // id passes 2^32, which defeats the int64_t index type.
  int64_t ids = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (ids < n * d) {
    int64_t idx_n = ids / d;
    int64_t idx_remain = ids % remain;
    int64_t idx_loss = idx_n * remain + idx_remain;
    logit_grad[ids] = loss_grad[idx_loss] * (logit_grad[ids] - labels[ids]);
  }
}
...@@ -132,19 +133,19 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage; ...@@ -132,19 +133,19 @@ using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
// This kernel is used to calculate the max element of each row // This kernel is used to calculate the max element of each row
template <typename T, int BlockDim> template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data, static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
int d, int axis_dim) { int64_t d, int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage; __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits_data view as [n, axis_dim, remain] // logits_data view as [n, axis_dim, remain]
// max_data view as [n, 1, remain] // max_data view as [n, 1, remain]
// blockDim = n * remain, split blockIdx to idx_n and idx_remain // blockDim = n * remain, split blockIdx to idx_n and idx_remain
int remain = d / axis_dim; int64_t remain = d / axis_dim;
int idx_n = blockIdx.x / remain; int64_t idx_n = blockIdx.x / remain;
int idx_remain = blockIdx.x % remain; int64_t idx_remain = blockIdx.x % remain;
int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int end_idx = (idx_n + 1) * d; int64_t end_idx = (idx_n + 1) * d;
int step = BlockDim * remain; int64_t step = BlockDim * remain;
T cur_max = logits_data[beg_idx]; T cur_max = logits_data[beg_idx];
beg_idx += step; beg_idx += step;
while (beg_idx < end_idx) { while (beg_idx < end_idx) {
...@@ -162,21 +163,21 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data, ...@@ -162,21 +163,21 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
// Make sure that BlockDim <= axis_dim // Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim, bool CalculateLogSoftmax = false> template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data, static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
T* max_data, T* softmax, int d, T* max_data, T* softmax,
int axis_dim) { int64_t d, int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage; __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits, softmax data view as [n, axis_dim, remain] // logits, softmax data view as [n, axis_dim, remain]
// max_data view as [n, 1, remain] // max_data view as [n, 1, remain]
// blockDim = n * remain, split blockIdx to idx_n and idx_remain // blockDim = n * remain, split blockIdx to idx_n and idx_remain
int remain = d / axis_dim; int64_t remain = d / axis_dim;
int idx_n = blockIdx.x / remain; int64_t idx_n = blockIdx.x / remain;
int idx_remain = blockIdx.x % remain; int64_t idx_remain = blockIdx.x % remain;
int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int end_idx = (idx_n + 1) * d; int64_t end_idx = (idx_n + 1) * d;
auto block_max = max_data[blockIdx.x]; auto block_max = max_data[blockIdx.x];
int step = BlockDim * remain; int64_t step = BlockDim * remain;
// In numeric stable mode softmax_with_loss, we calc loss with // In numeric stable mode softmax_with_loss, we calc loss with
// tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of
...@@ -216,25 +217,25 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data, ...@@ -216,25 +217,25 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
// Make sure that BlockDim <= axis_dim // Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim> template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy( static __global__ void RowReductionForSoftmaxAndCrossEntropy(
const T* logits_data, const T* labels_data, T* loss_data, T* softmax, int d, const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
int axis_dim) { int64_t d, int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage; __shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits, softmax, labels data view as [n, axis_dim, remain] // logits, softmax, labels data view as [n, axis_dim, remain]
// loss_data view as [n, 1, remain] // loss_data view as [n, 1, remain]
// blockDim = n * remain, split blockIdx to idx_n and idx_remain // blockDim = n * remain, split blockIdx to idx_n and idx_remain
int remain = d / axis_dim; int64_t remain = d / axis_dim;
int idx_n = blockIdx.x / remain; int64_t idx_n = blockIdx.x / remain;
int idx_remain = blockIdx.x % remain; int64_t idx_remain = blockIdx.x % remain;
int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain; int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int end_idx = (idx_n + 1) * d; int64_t end_idx = (idx_n + 1) * d;
// log_diff_max_sum shares memory with loss // log_diff_max_sum shares memory with loss
auto block_log_diff_max_sum = loss_data[blockIdx.x]; auto block_log_diff_max_sum = loss_data[blockIdx.x];
auto tmp = softmax[beg_idx] - block_log_diff_max_sum; auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
softmax[beg_idx] = exp_on_device(tmp); softmax[beg_idx] = exp_on_device(tmp);
auto loss = -labels_data[beg_idx] * tmp; auto loss = -labels_data[beg_idx] * tmp;
int step = BlockDim * remain; int64_t step = BlockDim * remain;
beg_idx += step; beg_idx += step;
while (beg_idx < end_idx) { while (beg_idx < end_idx) {
tmp = softmax[beg_idx] - block_log_diff_max_sum; tmp = softmax[beg_idx] - block_log_diff_max_sum;
...@@ -251,21 +252,22 @@ template <typename T> ...@@ -251,21 +252,22 @@ template <typename T>
struct HardLabelSoftmaxWithCrossEntropyFunctor { struct HardLabelSoftmaxWithCrossEntropyFunctor {
public: public:
HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss, HardLabelSoftmaxWithCrossEntropyFunctor(const int64_t* labels, T* loss,
T* log_softmax, int d, int axis_dim) T* log_softmax, int64_t d,
int axis_dim)
: labels_(labels), : labels_(labels),
loss_(loss), loss_(loss),
log_softmax_(log_softmax), log_softmax_(log_softmax),
d_(d), d_(d),
axis_dim_(axis_dim) {} axis_dim_(axis_dim) {}
__device__ void operator()(int idx) const { __device__ void operator()(int64_t idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain // logits view as [n, axis_dim, remain], where d = axis_dim * remain
int remain = d_ / axis_dim_; int64_t remain = d_ / axis_dim_;
int idx_n = idx / d_; int64_t idx_n = idx / d_;
int idx_axis = (idx % d_) / remain; int64_t idx_axis = (idx % d_) / remain;
int idx_remain = idx % remain; int64_t idx_remain = idx % remain;
// labels, loss view as [n, remain] // labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain; int64_t idx_lbl = idx_n * remain + idx_remain;
// It also would ignore labels not in range(class_num). // It also would ignore labels not in range(class_num).
if (idx_axis != labels_[idx_lbl]) { if (idx_axis != labels_[idx_lbl]) {
log_softmax_[idx] = exp_on_device(log_softmax_[idx]); log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
...@@ -280,7 +282,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor { ...@@ -280,7 +282,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctor {
const int64_t* labels_; const int64_t* labels_;
T* loss_; T* loss_;
T* log_softmax_; T* log_softmax_;
int d_; int64_t d_;
int axis_dim_; int axis_dim_;
}; };
...@@ -289,7 +291,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { ...@@ -289,7 +291,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
public: public:
HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels, HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx(const int64_t* labels,
T* loss, T* log_softmax, T* loss, T* log_softmax,
int d, int axis_dim, int64_t d, int axis_dim,
int ignore_idx) int ignore_idx)
: labels_(labels), : labels_(labels),
loss_(loss), loss_(loss),
...@@ -298,14 +300,14 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { ...@@ -298,14 +300,14 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
axis_dim_(axis_dim), axis_dim_(axis_dim),
ignore_idx_(ignore_idx) {} ignore_idx_(ignore_idx) {}
__device__ void operator()(int idx) const { __device__ void operator()(int64_t idx) const {
// logits view as [n, axis_dim, remain], where d = axis_dim * remain // logits view as [n, axis_dim, remain], where d = axis_dim * remain
int remain = d_ / axis_dim_; int64_t remain = d_ / axis_dim_;
int idx_n = idx / d_; int64_t idx_n = idx / d_;
int idx_axis = (idx % d_) / remain; int64_t idx_axis = (idx % d_) / remain;
int idx_remain = idx % remain; int64_t idx_remain = idx % remain;
// labels, loss view as [n, remain] // labels, loss view as [n, remain]
int idx_lbl = idx_n * remain + idx_remain; int64_t idx_lbl = idx_n * remain + idx_remain;
if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) { if (idx_axis != labels_[idx_lbl] || idx_axis == ignore_idx_) {
log_softmax_[idx] = exp_on_device(log_softmax_[idx]); log_softmax_[idx] = exp_on_device(log_softmax_[idx]);
} else { } else {
...@@ -319,7 +321,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { ...@@ -319,7 +321,7 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
const int64_t* labels_; const int64_t* labels_;
T* loss_; T* loss_;
T* log_softmax_; T* log_softmax_;
int d_; int64_t d_;
int axis_dim_; int axis_dim_;
int ignore_idx_; int ignore_idx_;
}; };
...@@ -327,13 +329,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx { ...@@ -327,13 +329,13 @@ struct HardLabelSoftmaxWithCrossEntropyFunctorWithIgnoreIdx {
template <typename T> template <typename T>
static void HardLabelSoftmaxWithCrossEntropy( static void HardLabelSoftmaxWithCrossEntropy(
const platform::CUDADeviceContext& ctx, const T* logits_data, const platform::CUDADeviceContext& ctx, const T* logits_data,
const int64_t* labels_data, T* loss_data, T* softmax_data, int n, int d, const int64_t* labels_data, T* loss_data, T* softmax_data, int64_t n,
int axis_dim, int ignore_idx) { int64_t d, int axis_dim, int ignore_idx) {
constexpr int kMaxBlockDim = 512; constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim int64_t block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim ? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim))); : (1 << static_cast<int>(std::log2(axis_dim)));
int grid_dim = n * d / axis_dim; int64_t grid_dim = n * d / axis_dim;
auto stream = ctx.stream(); auto stream = ctx.stream();
#define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ #define CALL_HARD_LABEL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
...@@ -372,16 +374,14 @@ static void HardLabelSoftmaxWithCrossEntropy( ...@@ -372,16 +374,14 @@ static void HardLabelSoftmaxWithCrossEntropy(
} }
template <typename T> template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(const T* logits_data, static void SoftmaxWithCrossEntropyFusedKernel(
const T* labels_data, const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
T* softmax_data, T* loss_data, int64_t n, int64_t d, int axis_dim, cudaStream_t stream) {
int n, int d, int axis_dim,
cudaStream_t stream) {
constexpr int kMaxBlockDim = 512; constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim int64_t block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim ? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim))); : (1 << static_cast<int>(std::log2(axis_dim)));
int grid_dim = n * d / axis_dim; int64_t grid_dim = n * d / axis_dim;
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \ #define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \ case BlockDim: \
...@@ -430,8 +430,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> { ...@@ -430,8 +430,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank); const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logits->dims()[axis]; int axis_dim = logits->dims()[axis];
const int n = SizeToAxis(axis, logits->dims()); const int64_t n = SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims()); const int64_t d = SizeFromAxis(axis, logits->dims());
auto* softmax_data = softmax->mutable_data<T>(context.GetPlace()); auto* softmax_data = softmax->mutable_data<T>(context.GetPlace());
auto* loss_data = loss->mutable_data<T>(context.GetPlace()); auto* loss_data = loss->mutable_data<T>(context.GetPlace());
...@@ -500,24 +500,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> { ...@@ -500,24 +500,24 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank); const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
int axis_dim = logit_grad->dims()[axis]; int axis_dim = logit_grad->dims()[axis];
const int n = SizeToAxis(axis, logit_grad->dims()); const int64_t n = SizeToAxis(axis, logit_grad->dims());
const int d = SizeFromAxis(axis, logit_grad->dims()); const int64_t d = SizeFromAxis(axis, logit_grad->dims());
const int remain = d / axis_dim; const int64_t remain = d / axis_dim;
int block = 512; int block = 512;
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index"); auto ignore_index = context.Attr<int>("ignore_index");
if (context.Attr<bool>("soft_label")) { if (context.Attr<bool>("soft_label")) {
int grid = (n * d + block - 1) / block; int64_t grid = (n * d + block - 1) / block;
const T* label_data = labels->data<T>(); const T* label_data = labels->data<T>();
SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>( SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
logit_grad_data, loss_grad_data, label_data, n, d, remain); logit_grad_data, loss_grad_data, label_data, n, d, remain);
} else { } else {
int grid = (n * remain + block - 1) / block; int64_t grid = (n * remain + block - 1) / block;
const int64_t* label_data = labels->data<int64_t>(); const int64_t* label_data = labels->data<int64_t>();
CrossEntropyGrad<T><<<grid, block, 0, stream>>>( CrossEntropyGrad<T><<<grid, block, 0, stream>>>(
logit_grad_data, label_data, n, d, remain, ignore_index); logit_grad_data, label_data, n, d, remain, ignore_index);
int num = n * d; int64_t num = n * d;
grid = (num + block - 1) / block; grid = (num + block - 1) / block;
Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num, Scale<T><<<grid, block, 0, stream>>>(logit_grad_data, loss_grad_data, num,
d, remain, label_data, ignore_index); d, remain, label_data, ignore_index);
......
...@@ -70,11 +70,14 @@ namespace platform { ...@@ -70,11 +70,14 @@ namespace platform {
* } * }
* *
*/ */
// CUDA_KERNEL_LOOP_TYPE(i, num, index_type)
//   Loop header for use inside a kernel: `i` starts at this thread's global
//   1-D index and advances by blockDim.x * gridDim.x until it reaches `num`.
//   `index_type` is the type of the user-visible loop variable `i`; pass
//   int64_t when `num` may exceed INT_MAX.
//
//   The running counter is kept in int64_t, and both the start index and the
//   stride are widened to int64_t BEFORE the multiplication. The built-in
//   index fields are 32-bit unsigned, so an unwidened product wraps at 2^32.
//
//   The macro declares a local named __index__, so use it at most once per
//   scope.
#define CUDA_KERNEL_LOOP_TYPE(i, num, index_type)                            \
  int64_t __index__ =                                                        \
      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;           \
  for (index_type i = __index__; __index__ < (num);                          \
       __index__ += static_cast<int64_t>(blockDim.x) * gridDim.x,            \
                  i = __index__)

// Backward-compatible form with a plain `int` loop variable.
#define CUDA_KERNEL_LOOP(i, num) CUDA_KERNEL_LOOP_TYPE(i, num, int)
class CublasHandleHolder { class CublasHandleHolder {
public: public:
CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) { CublasHandleHolder(cudaStream_t stream, cublasMath_t math_type) {
......
...@@ -48,7 +48,7 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) { ...@@ -48,7 +48,7 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
} }
template <typename Function> template <typename Function>
__global__ static void ForRangeElemwiseOp(Function func, int limit) { __global__ static void ForRangeElemwiseOp(Function func, size_t limit) {
size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x); size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < limit) { if (idx < limit) {
func(idx); func(idx);
...@@ -58,13 +58,13 @@ __global__ static void ForRangeElemwiseOp(Function func, int limit) { ...@@ -58,13 +58,13 @@ __global__ static void ForRangeElemwiseOp(Function func, int limit) {
template <> template <>
struct ForRange<CUDADeviceContext> { struct ForRange<CUDADeviceContext> {
ForRange(const CUDADeviceContext& dev_ctx, size_t limit) ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
: dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {} : dev_ctx_(dev_ctx), limit_(static_cast<size_t>(limit)) {}
template <typename Function> template <typename Function>
inline void operator()(Function func) const { inline void operator()(Function func) const {
constexpr int num_threads = 1024; constexpr int num_threads = 1024;
int block_size = limit_ <= num_threads ? limit_ : num_threads; size_t block_size = limit_ <= num_threads ? limit_ : num_threads;
int grid_size = (limit_ + num_threads - 1) / num_threads; size_t grid_size = (limit_ + num_threads - 1) / num_threads;
if (grid_size == 1) { if (grid_size == 1) {
ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>( ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
...@@ -76,7 +76,7 @@ struct ForRange<CUDADeviceContext> { ...@@ -76,7 +76,7 @@ struct ForRange<CUDADeviceContext> {
} }
const CUDADeviceContext& dev_ctx_; const CUDADeviceContext& dev_ctx_;
int limit_; size_t limit_;
}; };
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册