Unverified commit e896567e, authored by sneaxiy, committed by GitHub

Fix some operators when the tensor.numel() > INT32_MAX (#46767)

* fix some ops for int64 range

* update error message
Parent 05c2b9ba
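The root cause throughout this commit is 32-bit index arithmetic: products such as blockDim.y * blockIdx.y * gridDim.x or batch_count * key_seq_len are evaluated entirely in int before being stored, so once a tensor's numel() exceeds INT32_MAX the intermediate result wraps. The fix is to widen the first operand so the whole expression is computed in 64 bits. A minimal standalone sketch of the failure and the fix (illustrative values, not the operator code itself):

    #include <cstdint>
    #include <cstdio>

    int main() {
      // Mimics batch_count * key_seq_len for a tensor with numel() > INT32_MAX.
      int batch_count = 70000;
      int key_seq_len = 70000;  // 70000 * 70000 = 4.9e9 > INT32_MAX

      // Broken: both operands are int, so the product wraps (signed overflow)
      // before the assignment widens it to 64 bits.
      int64_t wrapped = batch_count * key_seq_len;

      // Fixed (the pattern used in the diff below): cast the first operand so
      // the entire expression is evaluated in 64-bit arithmetic.
      int64_t correct = static_cast<int64_t>(batch_count) * key_seq_len;

      std::printf("wrapped = %lld, correct = %lld\n",
                  static_cast<long long>(wrapped),
                  static_cast<long long>(correct));
      return 0;
    }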
@@ -127,26 +127,27 @@ __device__ __forceinline__ void warp_reduce_upper_tri(T* sum) {
 template <typename T, int pow2_index>
 __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
                                                       T* dst,
-                                                      int batch_count,
-                                                      int key_seq_len) {
+                                                      int64_t batch_count,
+                                                      int64_t key_seq_len) {
   constexpr int next_pow2 = 1 << pow2_index;
   constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
   constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
   constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
   constexpr int kOneLoadingCounts = 4;
-  int key_seq_len_pow_2 = key_seq_len * key_seq_len;
-  int first_idx =
-      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+  int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;
+  int64_t first_idx =
+      (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+          gridDim.x * kLocalBatchSize +
       blockIdx.x;
-  int local_block_idx = blockIdx.x + 1;
-  int warp_iter_upper_bound =
+  int64_t local_block_idx = blockIdx.x + 1;
+  int64_t warp_iter_upper_bound =
       (local_block_idx + kOneLoadingCounts * warp_size - 1) / warp_size;
-  int local_batches = batch_count - first_idx;
+  int64_t local_batches = batch_count - first_idx;
   if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
-  int local_idx = threadIdx.x;
+  int64_t local_idx = threadIdx.x;
   src += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
   dst += first_idx * key_seq_len + kOneLoadingCounts * local_idx;
@@ -156,11 +157,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
-    int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+    auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < batch_total_number) {
         load_data_upper_tri(temp_in,
@@ -215,7 +216,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGPUKernel(const T* src,
     if (i >= local_batches) break;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < local_block_idx) {
 #pragma unroll
@@ -241,31 +242,32 @@ template <typename T, int pow2_index>
 __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
                                                           T* grad_output,
                                                           const T* softmax_rst,
-                                                          int batch_count,
-                                                          int key_seq_len) {
+                                                          int64_t batch_count,
+                                                          int64_t key_seq_len) {
   constexpr int next_pow2 = 1 << pow2_index;
   constexpr int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
   constexpr int kLocalIterations = std::max(next_pow2 / warp_size, 4);
   constexpr int kLocalBatchSize = (next_pow2 <= 128) ? 2 : 1;
   constexpr int kOneLoadingCounts = 4;
-  int key_seq_len_pow_2 = key_seq_len * key_seq_len;
-  int first_idx =
-      (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * kLocalBatchSize +
+  int64_t key_seq_len_pow_2 = key_seq_len * key_seq_len;
+  int64_t first_idx =
+      (static_cast<int64_t>(blockDim.y) * blockIdx.y + threadIdx.y) *
+          gridDim.x * kLocalBatchSize +
       blockIdx.x;
-  int local_block_idx = blockIdx.x + 1;
+  int64_t local_block_idx = blockIdx.x + 1;
   // micro_batch_size might not be a multiple of WARP_BATCH. Check how
   // many batches have to be computed within this WARP.
-  int local_batches = batch_count - first_idx;
+  int64_t local_batches = batch_count - first_idx;
   if (local_batches > kLocalBatchSize) local_batches = kLocalBatchSize;
   // there might be multiple batches per warp. compute the index within the
   // batch
-  int local_idx = threadIdx.x;
+  int64_t local_idx = threadIdx.x;
   // the first element to process by the current thread
-  int offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
+  int64_t offset = first_idx * key_seq_len + kOneLoadingCounts * local_idx;
   grad_input += offset;
   grad_output += offset;
   softmax_rst += offset;
@@ -278,11 +280,11 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
 #pragma unroll
   for (int i = 0; i < kLocalBatchSize; ++i) {
-    int batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
+    auto batch_total_number = (i >= local_batches) ? 0 : local_block_idx;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < batch_total_number) {
         load_data_upper_tri(
             temp_grad_input,
@@ -327,7 +329,7 @@ __global__ void SoftmaxMaskFuseUpperTriangleGradGPUKernel(const T* grad_input,
     if (i >= local_batches) break;
 #pragma unroll
     for (int ii = 0; ii < kLocalIterations; ii += kOneLoadingCounts) {
-      int element_index = kOneLoadingCounts * local_idx + ii * warp_size;
+      auto element_index = kOneLoadingCounts * local_idx + ii * warp_size;
       if (element_index < key_seq_len) {
         // compute gradients
         T samples_out[kOneLoadingCounts];
@@ -368,10 +370,10 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
                           key_seq_len,
                           query_seq_len));
-    PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
+    PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len <= 16384,
                       true,
                       platform::errors::InvalidArgument(
-                          "Input x's last dim must be between [32, 8192) "
+                          "Input x's last dim must be between [32, 16384] "
                           "received the last dimension of x is %d",
                           key_seq_len));
@@ -380,7 +382,7 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
     int pow2_index = get_pow2_index_value(key_seq_len);
     const int next_pow2 = 1 << pow2_index;
-    int batch_count = attn_mul_batch * query_seq_len;
+    int64_t batch_count = attn_mul_batch * query_seq_len;
     int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
     int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
     constexpr int threads_per_block = 128;
@@ -447,7 +449,13 @@ class SoftmaxMaskFuseUpperTriangleKernel : public framework::OpKernel<T> {
             <<<blocks, threads, 0, stream>>>(
                 x_data, y_data, batch_count, key_seq_len);
         break;
+      case 14:  // 16384
+        SoftmaxMaskFuseUpperTriangleGPUKernel<T, 14>
+            <<<blocks, threads, 0, stream>>>(
+                x_data, y_data, batch_count, key_seq_len);
+        break;
       default:
+        PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
         break;
     }
   }
@@ -479,7 +487,7 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
     int pow2_index = get_pow2_index_value(key_seq_len);
     const int next_pow2 = 1 << pow2_index;
-    int batch_count = attn_mul_batch * query_seq_len;
+    int64_t batch_count = attn_mul_batch * query_seq_len;
     int warp_size = (next_pow2 < WARP_SIZE) ? next_pow2 : WARP_SIZE;
     int batches_per_warp = (next_pow2 <= 128) ? 2 : 1;
     // use 128 threads per block to maximize gpu utilization
@@ -565,7 +573,16 @@ class SoftmaxMaskFuseUpperTriangleGradKernel : public framework::OpKernel<T> {
                                              batch_count,
                                              key_seq_len);
         break;
+      case 14:  // 16384
+        SoftmaxMaskFuseUpperTriangleGradGPUKernel<T, 14>
+            <<<blocks, threads, 0, stream>>>(grad_y_data,
+                                             grad_x_data,
+                                             softmax_rst_data,
+                                             batch_count,
+                                             key_seq_len);
+        break;
       default:
+        PADDLE_THROW(phi::errors::Unimplemented("Too large sequence length."));
         break;
     }
   }
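Both switch statements dispatch to kernels that are compile-time specialized on pow2_index, so every supported power of two needs its own case arm: the new case 14 (2^14 = 16384) is what actually raises the sequence-length ceiling, and the default arm now fails loudly with PADDLE_THROW instead of silently skipping the launch. get_pow2_index_value itself is not part of this diff; a hedged sketch of the mapping it presumably computes (hypothetical helper, not the Paddle implementation):

    #include <cstdint>

    // Hypothetical stand-in for get_pow2_index_value(): the exponent of the
    // smallest power of two >= key_seq_len, e.g. 16384 -> 14, which selects
    // the <T, 14> kernel specialization added above.
    static int NextPow2Exponent(int64_t key_seq_len) {
      int exponent = 5;  // kernels start at 2^5 = 32, the enforced minimum
      while ((int64_t{1} << exponent) < key_seq_len) ++exponent;
      return exponent;
    }

Under that mapping, the relaxed check (key_seq_len <= 16384) and the case arms stay in sync: every length the PADDLE_ENFORCE_EQ admits has a matching specialization.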
@@ -760,8 +760,10 @@ __global__ void VectorizedElementwiseKernel(
                                             kps::IndexType main_offset,
                                             int read_lens,
                                             Functor func) {
-  kps::IndexType data_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
-  kps::IndexType stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
+  kps::IndexType data_offset =
+      static_cast<kps::IndexType>(BLOCK_ID_X) * BLOCK_NUM_X * read_lens;
+  kps::IndexType stride =
+      static_cast<kps::IndexType>(BLOCK_NUM_X) * GRID_NUM_X * read_lens;
   for (; data_offset < main_offset; data_offset += stride) {
     VectorizedElementwiseKernelImpl<OutT,
                                     Functor,
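The elementwise kernel had the same overflow in its grid-stride bookkeeping: BLOCK_ID_X * BLOCK_NUM_X * read_lens is an all-int product that wraps on tensors with more than INT32_MAX elements even though the result is assigned to kps::IndexType, so the cast has to land on the first factor. A self-contained CUDA sketch of the 64-bit-safe grid-stride pattern (int64_t standing in for kps::IndexType; the scaling body is illustrative):

    #include <cstdint>

    // Grid-stride loop with 64-bit offsets, safe for numel() > INT32_MAX.
    __global__ void ScaleKernel(const float* x, float* y, float alpha,
                                int64_t numel) {
      // Widen before multiplying: blockIdx.x * blockDim.x on its own is
      // computed in 32 bits and can wrap for grids addressing > 2^31 elements.
      int64_t offset =
          static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
      for (; offset < numel; offset += stride) {
        y[offset] = alpha * x[offset];
      }
    }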