From 393b3bd6b7adadedc21d801c68c5bd002047fdc3 Mon Sep 17 00:00:00 2001
From: Thunderbrook <52529258+Thunderbrook@users.noreply.github.com>
Date: Wed, 31 Mar 2021 11:14:06 +0800
Subject: [PATCH] fix split core (#31892)

* fix split core

* format
---
 .../fluid/operators/math/concat_and_split.cu | 48 +++++++++----------
 1 file changed, 24 insertions(+), 24 deletions(-)

diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu
index a29997e5654..d62c1e42d3b 100644
--- a/paddle/fluid/operators/math/concat_and_split.cu
+++ b/paddle/fluid/operators/math/concat_and_split.cu
@@ -114,8 +114,8 @@ __global__ void ConcatKernel(const T** inputs_data, const int in_num,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int* out_cols,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t* out_cols,
                             int out_cols_size, T** outputs_data) {
   int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
   int curr_segment = 0;
@@ -159,15 +159,15 @@ __device__ void SplitKernelDetail(const T* input_data, const int in_row,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T** outputs_data) {
   SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data);
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T* outputs_addr0, T* outputs_addr1) {
   T* outputs_data[2];
   outputs_data[0] = outputs_addr0;
@@ -176,8 +176,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T* outputs_addr0, T* outputs_addr1,
                             T* outputs_addr2) {
   T* outputs_data[3];
@@ -188,8 +188,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
 }
 
 template <typename T>
-__global__ void SplitKernel(const T* input_data, const int in_row,
-                            const int in_col, const int fixed_out_col,
+__global__ void SplitKernel(const T* input_data, const int64_t in_row,
+                            const int64_t in_col, const int64_t fixed_out_col,
                             T* outputs_addr0, T* outputs_addr1,
                             T* outputs_addr2, T* outputs_addr3) {
   T* outputs_data[4];
@@ -201,8 +201,8 @@ __global__ void SplitKernel(const T* input_data, const int in_row,
 }
 
 static inline void GetBlockDims(const platform::CUDADeviceContext& context,
-                                int num_rows, int num_cols, dim3* block_dims,
-                                dim3* grid_dims) {
+                                int64_t num_rows, int64_t num_cols,
+                                dim3* block_dims, dim3* grid_dims) {
   // Set the thread block and grid according to CurrentDeviceId
   const int kThreadsPerBlock = 1024;
   int block_cols = kThreadsPerBlock;
@@ -213,12 +213,12 @@ static inline void GetBlockDims(const platform::CUDADeviceContext& context,
   *block_dims = dim3(block_cols, block_rows, 1);
 
   int max_threads = context.GetMaxPhysicalThreadCount();
-  int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+  int64_t max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
 
   int grid_cols =
       std::min((num_cols + block_cols - 1) / block_cols, max_blocks);
-  int grid_rows =
-      std::min(max_blocks / grid_cols, std::max(num_rows / block_rows, 1));
+  int grid_rows = std::min(max_blocks / grid_cols,
+                           std::max(num_rows / block_rows, (int64_t)1));
   *grid_dims = dim3(grid_cols, grid_rows, 1);
 }
 
@@ -319,22 +319,22 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
                   int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
-    int out_row = 1;
+    int64_t out_row = 1;
     auto dim_0 = ref_inputs[0]->dims();
     for (int i = 0; i < axis; ++i) {
       out_row *= dim_0[i];
     }
 
-    int out0_col = ref_inputs[0]->numel() / out_row;
-    int in_col = 0, in_row = out_row;
+    int64_t out0_col = ref_inputs[0]->numel() / out_row;
+    int64_t in_col = 0, in_row = out_row;
     bool has_same_shape = true;
 
     std::vector<T*> outputs_data(o_num);
-    std::vector<int> outputs_cols(o_num + 1);
+    std::vector<int64_t> outputs_cols(o_num + 1);
 
     outputs_cols[0] = 0;
     for (int i = 0; i < o_num; ++i) {
-      int t_col = ref_inputs.at(i)->numel() / out_row;
+      int64_t t_col = ref_inputs.at(i)->numel() / out_row;
       if (has_same_shape) {
         if (t_col != out0_col) has_same_shape = false;
       }
@@ -384,13 +384,13 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
 
       auto tmp_dev_ins_col_data =
           memory::Alloc(context,
-                        outputs_cols.size() * sizeof(int));
+                        outputs_cols.size() * sizeof(int64_t));
       memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()),
                    tmp_dev_ins_col_data->ptr(), platform::CPUPlace(),
                    reinterpret_cast<void*>(outputs_cols.data()),
-                   outputs_cols.size() * sizeof(int), context.stream());
-      int* dev_outs_col_data =
-          reinterpret_cast<int*>(tmp_dev_ins_col_data->ptr());
+                   outputs_cols.size() * sizeof(int64_t), context.stream());
+      int64_t* dev_outs_col_data =
+          reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
 
       SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
           input.data<T>(), in_row, in_col, dev_outs_col_data,
-- 
GitLab
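
Why the widening matters: these kernels compute flat element offsets from the row and column extents, and with 32-bit int a product such as row * in_col overflows once the input tensor holds more than 2^31 - 1 elements, yielding a negative offset and the out-of-bounds access (the "core" dump in the subject line). Below is a minimal host-side C++ sketch of that overflow; it is not part of the patch, and the 70000 x 40000 shape is a hypothetical example chosen only to exceed INT_MAX.

// Minimal sketch (not from the patch): the 32-bit overflow that the
// int -> int64_t widening in SplitKernel/GetBlockDims guards against.
#include <cstdint>
#include <cstdio>

int main() {
  int in_row = 70000;  // hypothetical split input: 70000 x 40000,
  int in_col = 40000;  // ~2.8e9 elements, above INT_MAX (~2.15e9)
  // 32-bit signed multiply overflows (undefined behavior; in practice it
  // wraps to a negative value), so an offset computed this way would
  // index out of bounds.
  int bad_numel = in_row * in_col;
  // Widened arithmetic, as the patched kernels now do with int64_t
  // parameters: the product stays in range.
  int64_t good_numel = static_cast<int64_t>(in_row) * in_col;
  printf("int: %d  int64_t: %lld\n", bad_numel,
         static_cast<long long>(good_numel));
  return 0;
}

Widening the kernel parameters (rather than clamping shapes on the host) keeps every intermediate product 64-bit on the device, which is why the patch also moves the output-column table to int64_t end to end, including the device copy sizes.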