diff --git a/paddle/phi/kernels/funcs/concat_and_split_functor.cu b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
index 22dba8297d65bb946ad81fb0e959a9fd47882c73..dbcd4016170d57515df67d4b8274aab41685ba73 100644
--- a/paddle/phi/kernels/funcs/concat_and_split_functor.cu
+++ b/paddle/phi/kernels/funcs/concat_and_split_functor.cu
@@ -26,22 +26,21 @@ __global__ void ConcatKernel_(const T** inputs,
                               const int64_t output_rows,
                               const int64_t output_cols,
                               T* output) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int curr_segment = 0;
-  int curr_offset = input_cols[0];
-  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
-    int curr_col_offset = input_cols[curr_segment + 1];
+  int64_t curr_segment = 0;
+  int64_t curr_offset = input_cols[0];
+  CUDA_KERNEL_LOOP_TYPE(tid_x, output_cols, int64_t) {
+    int64_t curr_col_offset = input_cols[curr_segment + 1];
     while (curr_col_offset <= tid_x) {
       curr_offset = curr_col_offset;
       ++curr_segment;
       curr_col_offset = input_cols[curr_segment + 1];
     }
-    int local_col = tid_x - curr_offset;
-    int segment_width = curr_col_offset - curr_offset;
+    int64_t local_col = tid_x - curr_offset;
+    int64_t segment_width = curr_col_offset - curr_offset;
     const T* input_ptr = inputs[curr_segment];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
       output[tid_y * output_cols + tid_x] =
           input_ptr[tid_y * segment_width + local_col];
@@ -50,16 +49,15 @@ __global__ void ConcatKernel_(const T** inputs,
 
 template <typename T>
 __device__ void ConcatKernelDetail(const T** inputs_data,
-                                   const int fixed_in_col,
-                                   const int out_rows,
-                                   const int out_cols,
+                                   const int64_t fixed_in_col,
+                                   const int64_t out_rows,
+                                   const int64_t out_cols,
                                    T* output_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * 1.0 / fixed_in_col;
-    int in_offset = tid_x - split * fixed_in_col;
+  CUDA_KERNEL_LOOP_TYPE(tid_x, out_cols, int64_t) {
+    int64_t split = tid_x * 1.0 / fixed_in_col;
+    int64_t in_offset = tid_x - split * fixed_in_col;
     const T* input_ptr = inputs_data[split];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
     for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
       output_data[tid_y * out_cols + tid_x] =
           input_ptr[tid_y * fixed_in_col + in_offset];
@@ -133,22 +131,21 @@ __global__ void SplitKernel_(const T* input_data,
                              const int64_t* out_cols,
                              int out_cols_size,
                              T** outputs_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int curr_segment = 0;
-  int curr_offset = out_cols[0];
-  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
-    int curr_col_offset = out_cols[curr_segment + 1];
+  int64_t curr_segment = 0;
+  int64_t curr_offset = out_cols[0];
+  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
+    int64_t curr_col_offset = out_cols[curr_segment + 1];
     while (curr_col_offset <= tid_x) {
       curr_offset = curr_col_offset;
       ++curr_segment;
       curr_col_offset = out_cols[curr_segment + 1];
     }
-    int local_col = tid_x - curr_offset;
-    int segment_width = curr_col_offset - curr_offset;
+    int64_t local_col = tid_x - curr_offset;
+    int64_t segment_width = curr_col_offset - curr_offset;
     T* output_ptr = outputs_data[curr_segment];
     if (output_ptr != nullptr) {
-      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
       for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
         output_ptr[tid_y * segment_width + local_col] =
             input_data[tid_y * in_col + tid_x];
@@ -158,17 +155,16 @@ __global__ void SplitKernel_(const T* input_data,
 
 template <typename T>
 __device__ void SplitKernelDetail(const T* input_data,
-                                  const int in_row,
-                                  const int in_col,
-                                  const int fixed_out_col,
+                                  const int64_t in_row,
+                                  const int64_t in_col,
+                                  const int64_t fixed_out_col,
                                   T** outputs_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < in_col; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x / fixed_out_col;
-    int in_offset = tid_x - split * fixed_out_col;
+  CUDA_KERNEL_LOOP_TYPE(tid_x, in_col, int64_t) {
+    int64_t split = tid_x / fixed_out_col;
+    int64_t in_offset = tid_x - split * fixed_out_col;
     T* output_ptr = outputs_data[split];
     if (output_ptr != nullptr) {
-      int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+      int64_t tid_y = blockIdx.y * blockDim.y + threadIdx.y;
       for (; tid_y < in_row; tid_y += blockDim.y * gridDim.y)
         output_ptr[tid_y * fixed_out_col + in_offset] =
             input_data[tid_y * in_col + tid_x];
@@ -266,7 +262,7 @@ struct ConcatFunctor {
                   int axis,
                   phi::DenseTensor* output) {
     // TODO(zcd): Add input data validity checking
-    int in_num = input.size();
+    int64_t in_num = input.size();
    int64_t in_row = 1;
     auto dim_0 = input[0].dims();
     for (int i = 0; i < axis; ++i) {
@@ -275,7 +271,7 @@ struct ConcatFunctor {
     int64_t in_col = input[0].numel() / in_row;
     int64_t out_row = in_row, out_col = 0;
 
-    int inputs_col_num = in_num + 1;
+    int64_t inputs_col_num = in_num + 1;
     std::vector<const T*> inputs_data_vec(in_num);
     std::vector<int64_t> inputs_col_vec(inputs_col_num);
     const T** inputs_data = inputs_data_vec.data();
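
Note on the pattern used above: the hunks replace hand-rolled `for (; tid_x < n; tid_x += blockDim.x * gridDim.x)` loops with Paddle's CUDA_KERNEL_LOOP_TYPE macro and widen the per-thread indices from int to int64_t, so column and row offsets no longer overflow once a concatenated or split tensor has more than INT_MAX elements along the flattened axis. The sketch below only illustrates the 64-bit grid-stride-loop idea involved; it is not Paddle's actual macro definition, and ScaleKernel is a made-up example kernel.

// Minimal sketch of a grid-stride loop with 64-bit indexing, the pattern
// CUDA_KERNEL_LOOP_TYPE(tid_x, n, int64_t) stands for in spirit (assumed,
// not Paddle's real macro). ScaleKernel is a hypothetical example.
#include <cstdint>

template <typename T>
__global__ void ScaleKernel(const T* in, T* out, int64_t n, T factor) {
  // Compute index and stride in 64-bit arithmetic so the loop stays
  // correct when n exceeds INT_MAX.
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t i = idx; i < n; i += stride) {
    out[i] = in[i] * factor;  // each thread handles every stride-th element
  }
}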