From e896567eefbcd80be25ec1f899003ff5fbe17e1b Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Wed, 12 Oct 2022 11:32:42 +0800
Subject: [PATCH] Fix some operators when the tensor.numel() > INT32_MAX
 (#46767)

* fix some ops for int64 range

* update error message
---
 .../fused_softmax_mask_upper_triangle_op.cu   | 73 ++++++++++++-------
 paddle/phi/kernels/funcs/elementwise_base.h   |  6 +-
 2 files changed, 49 insertions(+), 30 deletions(-)
 mode change 100755 => 100644 paddle/phi/kernels/funcs/elementwise_base.h

diff --git a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
index 4a592508474..41c42e6134e 100644
--- a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
+++ b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
@@ -127,26 +127,27 @@ __device__ __forceinline__ void warp_reduce_upper_tri(T* sum) {
 template <typename T, int pow2_index>
 __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
                                                       T* dst,
-                                                      int batch_count,
-                                                      int key_seq_len) {
+                                                      int64_t batch_count,
+                                                      int64_t key_seq_len) {
   constexpr int next_pow2 = 1 << pow2_index;
   constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
   constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
   constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
   constexpr int kOneLoadingCounts = 4;
-  int key_seq_len_pow_2 = key_seq_len * key_seq_len;
+  int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;
 
-  int first_idx =
-      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+  int64_t first_idx =
+      (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+          gridDim.x * kLocalBatchSize +
       blockIdx.x;
-  int local_block_idx = blockIdx.x + 1;
-  int warp_iter_upper_bound =
+  int64_t local_block_idx = blockIdx.x + 1;
+  int64_t warp_iter_upper_bound =
       (local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size;
 
-  int local_batches = batch_count - first_idx;
+  int64_t local_batches = batch_count - first_idx;
   if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
 
-  int local_idx = threadIdx.x;
+  int64_t local_idx = threadIdx.x;
 
   src += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
   dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
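A note on the casts above: blockDim, blockIdx and gridDim components are all
unsigned int in CUDA, so a product like blockDim.y * blockIdx.y * gridDim.x is
evaluated in 32-bit arithmetic and wraps before an assignment to int64_t can
widen it; the cast has to land on an operand, not on the destination. A
minimal host-side sketch of that promotion rule (standalone, all names
hypothetical, not part of the patch):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Stand-ins for the unsigned CUDA built-ins used in first_idx above.
      unsigned int block_dim_y = 1024, block_idx_y = 1023,
                   grid_dim_x = 1u << 22;
      // Product of unsigned ints: computed modulo 2^32, widened too late.
      int64_t wrapped = block_dim_y * block_idx_y * grid_dim_x;
      // Casting the first operand promotes the whole chain to 64-bit math.
      int64_t widened =
          static_cast<int64_t>(block_dim_y) * block_idx_y * grid_dim_x;
      std::printf("wrapped=%lld widened=%lld\n",
                  static_cast<long long>(wrapped),
                  static_cast<long long>(widened));
      return 0;
    }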
@@ -156,11 +157,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
 
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
-    int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+    auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
 
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
 
       if (element_index < batch_total_number) {
         load_data_upper_tri(temp_in,
@@ -215,7 +216,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
     if (i >= local_batches) break;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
 
       if (element_index < local_block_idx) {
 #pragma unroll
@@ -241,31 +242,32 @@ template <typename T, int pow2_index>
 __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
                                                           T* grad_output,
                                                           const T* softmax_rst,
-                                                          int batch_count,
-                                                          int key_seq_len) {
+                                                          int64_t batch_count,
+                                                          int64_t key_seq_len) {
   constexpr int next_pow2 = 1 << pow2_index;
   constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
   constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
   constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
   constexpr int kOneLoadingCounts = 4;
-  int key_seq_len_pow_2 = key_seq_len * key_seq_len;
+  int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;
 
-  int first_idx =
-      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+  int64_t first_idx =
+      (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+          gridDim.x * kLocalBatchSize +
       blockIdx.x;
-  int local_block_idx = blockIdx.x + 1;
+  int64_t local_block_idx = blockIdx.x + 1;
 
   // micro_batch_size might not be a multiple of WARP_BATCH. Check how
   // many batches have to be computed within this WARP.
-  int local_batches = batch_count - first_idx;
+  int64_t local_batches = batch_count - first_idx;
   if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
 
   // there might be multiple batches per warp. compute the index within the
   // batch
-  int local_idx = threadIdx.x;
+  int64_t local_idx = threadIdx.x;
 
   // the first element to process by the current thread
-  int offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
+  int64_t offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
   grad_input += offset;
   grad_output += offset;
   softmax_rst += offset;
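The hunks below switch batch_total_number and element_index from int to auto.
Since local_batches, local_block_idx and local_idx are now int64_t, the usual
arithmetic conversions widen the mixed expressions, and auto tracks the
widened type at every use site. A small standalone sketch of that deduction
(names hypothetical):

    #include <cstdint>
    #include <type_traits>

    int main() {
      int64_t local_idx = 7;  // the operand the patch promotes
      int warp_size = 32;     // still a plain int
      // int * int64_t + int * int -> int64_t by the usual arithmetic
      // conversions, so no per-site edits are needed once auto is used.
      auto element_index = 4 * local_idx + 2 * warp_size;
      static_assert(std::is_same<decltype(element_index), int64_t>::value,
                    "auto deduces the widest operand type");
      return 0;
    }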
@@ -278,11 +280,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
 
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
-    int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+    auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
 
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
 
       if (element_index < batch_total_number) {
         load_data_upper_tri(
             temp_grad_input,
@@ -327,7 +329,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
     if (i >= local_batches) break;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < key_seq_len) {
         // compute gradients
         T samples_out[kOneLoadingCounts];
@@ -368,10 +370,10 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
                           key_seq_len,
                           query_seq_len));
 
-    PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
+    PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len <= 16384,
                       true,
                       platform::errors::InvalidArgument(
-                          "Input x's last dim must be between [32, 8192) "
+                          "Input x's last dim must be between [32, 16384] "
                           "received the last dimension of x is %d",
                           key_seq_len));
 
@@ -380,7 +382,7 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
     int pow2_index = get_pow2_index_value(key_seq_len);
     const int next_pow2 = 1 << pow2_index;
 
-    int batch_count = attn_mul_batch * query_seq_len;
+    int64_t batch_count = attn_mul_batch * query_seq_len;
     int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
     int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
     constexpr int threads_per_block = 128;
@@ -447,7 +449,13 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
             <<<blocks, threads, 0, stream>>>(
                 x_data, y_data, batch_count, key_seq_len);
         break;
+      case 14:  // 16384
+        SoftmaxMaskFuseUpperTriangleGPUKernel<T, 14>
+            <<<blocks, threads, 0, stream>>>(
+                x_data, y_data, batch_count, key_seq_len);
+        break;
       default:
+        PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
         break;
     }
   }
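The new case 14 above extends the dispatch to next_pow2 = 2^14 = 16384, in
line with the relaxed PADDLE_ENFORCE_EQ bound, and the default branch now
fails loudly instead of silently skipping the launch. A sketch of the mapping
the switch relies on (an assumption about what get_pow2_index_value computes,
not its actual implementation):

    #include <cassert>
    #include <cstdint>

    // Smallest p with (1 << p) >= key_seq_len, floored at 5 (32 lanes), so
    // case 14 covers key_seq_len in (8192, 16384].
    static int pow2_index_sketch(int64_t key_seq_len) {
      int p = 5;
      while ((int64_t{1} << p) < key_seq_len) ++p;
      return p;
    }

    int main() {
      assert(pow2_index_sketch(32) == 5);
      assert(pow2_index_sketch(8192) == 13);
      assert(pow2_index_sketch(16384) == 14);
      return 0;
    }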
@@ -479,7 +487,7 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
     int pow2_index = get_pow2_index_value(key_seq_len);
     const int next_pow2 = 1 << pow2_index;
 
-    int batch_count = attn_mul_batch * query_seq_len;
+    int64_t batch_count = attn_mul_batch * query_seq_len;
     int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
     int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
     // use 128 threads per block to maximize gpu utilization
@@ -565,7 +573,16 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
                                                 batch_count,
                                                 key_seq_len);
         break;
+      case 14:
+        SoftmaxMaskFuseUpperTriangleGradGPUKernel<T, 14>
+            <<<blocks, threads, 0, stream>>>(grad_y_data,
+                                             grad_x_data,
+                                             softmax_rst_data,
+                                             batch_count,
+                                             key_seq_len);
+        break;
       default:
+        PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
         break;
     }
   }
diff --git a/paddle/phi/kernels/funcs/elementwise_base.h b/paddle/phi/kernels/funcs/elementwise_base.h
old mode 100755
new mode 100644
index 2573a0e44c9..100d2dcd612
--- a/paddle/phi/kernels/funcs/elementwise_base.h
+++ b/paddle/phi/kernels/funcs/elementwise_base.h
@@ -760,8 +760,10 @@ __global__ void VectorizedElementwiseKernel(
     kps::IndexType main_offset,
     int read_lens,
     Functor func) {
-  kps::IndexType data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
-  kps::IndexType stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
+  kps::IndexType data_offset =
+      static_cast<kps::IndexType>(BLOCK_ID_X) * BLOCK_NUM_X * read_lens;
+  kps::IndexType stride =
+      static_cast<kps::IndexType>(BLOCK_NUM_X) * GRID_NUM_X * read_lens;
 
   for (; data_offset < main_offset; data_offset += stride) {
     VectorizedElementwiseKernelImpl
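The elementwise_base.h change follows the same pattern as the kernel fixes:
BLOCK_ID_X * BLOCK_NUM_X is still 32-bit math unless one operand is widened
first, and assigning the result to kps::IndexType does not undo the wrap. A
standalone sketch of the hardened grid-stride idiom (kernel and names
hypothetical, not Paddle's API):

    #include <cstdint>

    // Grid-stride loop that stays correct past INT32_MAX elements: both the
    // starting offset and the stride are computed in 64-bit from the start.
    __global__ void scale_kernel(const float* x,
                                 float* y,
                                 int64_t numel,
                                 float alpha) {
      int64_t offset =
          static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
      for (int64_t i = offset; i < numel; i += stride) {
        y[i] = alpha * x[i];
      }
    }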