From c342651e576d42659e93d7faceee24e1c01ac312 Mon Sep 17 00:00:00 2001 From: wuhuachaocoding <77733235+wuhuachaocoding@users.noreply.github.com> Date: Thu, 22 Jul 2021 20:39:10 +0800 Subject: [PATCH] fix concat bug (#34319) --- .../fluid/operators/math/concat_and_split.cc | 18 ++++---- .../fluid/operators/math/concat_and_split.cu | 46 ++++++++++--------- 2 files changed, 33 insertions(+), 31 deletions(-) diff --git a/paddle/fluid/operators/math/concat_and_split.cc b/paddle/fluid/operators/math/concat_and_split.cc index 7df78b321de..6c1ee863737 100644 --- a/paddle/fluid/operators/math/concat_and_split.cc +++ b/paddle/fluid/operators/math/concat_and_split.cc @@ -40,18 +40,18 @@ class ConcatFunctor { const std::vector& input, int axis, framework::Tensor* output) { // TODO(zcd): Add input data validity checking - int num = input.size(); + size_t num = input.size(); - int rows = 1; + int64_t rows = 1; auto dim_0 = input[0].dims(); for (int i = 0; i < axis; ++i) { rows *= dim_0[i]; } - int out_rows = rows, out_cols = 0; + int64_t out_rows = rows, out_cols = 0; std::vector input_cols(input.size()); - for (int i = 0; i < num; ++i) { - int t_cols = input[i].numel() / rows; + for (size_t i = 0; i < num; ++i) { + int64_t t_cols = input[i].numel() / rows; out_cols += t_cols; input_cols[i] = t_cols; } @@ -59,11 +59,11 @@ class ConcatFunctor { // computation auto output_data = output->data(); - int col_idx = 0; - for (int j = 0; j < num; ++j) { - int col_len = input_cols[j]; + int64_t col_idx = 0; + for (size_t j = 0; j < num; ++j) { + int64_t col_len = input_cols[j]; auto input_data = input[j].data(); - for (int k = 0; k < out_rows; ++k) { + for (int64_t k = 0; k < out_rows; ++k) { memory::Copy(cpu_place, output_data + k * out_cols + col_idx, cpu_place, input_data + k * col_len, sizeof(T) * col_len); } diff --git a/paddle/fluid/operators/math/concat_and_split.cu b/paddle/fluid/operators/math/concat_and_split.cu index 58f936788a3..f9cce061383 100644 --- a/paddle/fluid/operators/math/concat_and_split.cu +++ b/paddle/fluid/operators/math/concat_and_split.cu @@ -26,9 +26,9 @@ namespace operators { namespace math { template -__global__ void ConcatKernel(const T** inputs, const int* input_cols, - int col_size, const int output_rows, - const int output_cols, T* output) { +__global__ void ConcatKernel(const T** inputs, const int64_t* input_cols, + int col_size, const int64_t output_rows, + const int64_t output_cols, T* output) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int curr_segment = 0; int curr_offset = input_cols[0]; @@ -70,8 +70,8 @@ __device__ void ConcatKernelDetail(const T** inputs_data, template __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, - const int fixed_in_col, const int out_rows, - const int out_cols, T* output_data) { + const int64_t fixed_in_col, const int64_t out_rows, + const int64_t out_cols, T* output_data) { const T* inputs_data[2]; inputs_data[0] = input_addr0; inputs_data[1] = input_addr1; @@ -81,8 +81,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, template __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, - const T* input_addr2, const int fixed_in_col, - const int out_rows, const int out_cols, + const T* input_addr2, const int64_t fixed_in_col, + const int64_t out_rows, const int64_t out_cols, T* output_data) { const T* inputs_data[3]; inputs_data[0] = input_addr0; @@ -95,8 +95,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, template __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, const T* input_addr2, const T* input_addr3, - const int fixed_in_col, const int out_rows, - const int out_cols, T* output_data) { + const int64_t fixed_in_col, const int64_t out_rows, + const int64_t out_cols, T* output_data) { const T* inputs_data[4]; inputs_data[0] = input_addr0; inputs_data[1] = input_addr1; @@ -108,8 +108,8 @@ __global__ void ConcatKernel(const T* input_addr0, const T* input_addr1, template __global__ void ConcatKernel(const T** inputs_data, const int in_num, - const int fixed_in_col, const int out_rows, - const int out_cols, T* output_data) { + const int64_t fixed_in_col, const int64_t out_rows, + const int64_t out_cols, T* output_data) { ConcatKernelDetail(inputs_data, fixed_in_col, out_rows, out_cols, output_data); } @@ -235,19 +235,19 @@ class ConcatFunctor { framework::Tensor* output) { // TODO(zcd): Add input data validity checking int in_num = input.size(); - int in_row = 1; + int64_t in_row = 1; auto dim_0 = input[0].dims(); for (int i = 0; i < axis; ++i) { in_row *= dim_0[i]; } - int in_col = input[0].numel() / in_row; - int out_row = in_row, out_col = 0; + int64_t in_col = input[0].numel() / in_row; + int64_t out_row = in_row, out_col = 0; int inputs_col_num = in_num + 1; std::vector inputs_data_vec(in_num); - std::vector inputs_col_vec(inputs_col_num); + std::vector inputs_col_vec(inputs_col_num); const T** inputs_data = inputs_data_vec.data(); - int* inputs_col = inputs_col_vec.data(); + int64_t* inputs_col = inputs_col_vec.data(); // There are some differences between hip runtime and NV runtime. // In NV, when the pageable memory data less than 64K is transferred from @@ -263,13 +263,13 @@ class ConcatFunctor { inputs_data = reinterpret_cast(data_alloc->ptr()); col_alloc = memory::Alloc(platform::CUDAPinnedPlace(), inputs_col_num * sizeof(int)); - inputs_col = reinterpret_cast(col_alloc->ptr()); + inputs_col = reinterpret_cast(col_alloc->ptr()); #endif inputs_col[0] = 0; bool has_same_shape = true; for (int i = 0; i < in_num; ++i) { - int t_cols = input[i].numel() / in_row; + int64_t t_cols = input[i].numel() / in_row; if (has_same_shape) { if (t_cols != in_col) has_same_shape = false; } @@ -312,17 +312,19 @@ class ConcatFunctor { } } else { auto tmp_dev_ins_col_data = - memory::Alloc(context, inputs_col_num * sizeof(int)); + memory::Alloc(context, inputs_col_num * sizeof(int64_t)); memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), - static_cast(inputs_col), inputs_col_num * sizeof(int), - context.stream()); - int* dev_ins_col_data = static_cast(tmp_dev_ins_col_data->ptr()); + static_cast(inputs_col), + inputs_col_num * sizeof(int64_t), context.stream()); + int64_t* dev_ins_col_data = + static_cast(tmp_dev_ins_col_data->ptr()); ConcatKernel<<>>( dev_ins_data, dev_ins_col_data, static_cast(inputs_col_num), out_row, out_col, output->data()); } + #ifdef PADDLE_WITH_HIP // Prevent the pinned memory value from being covered and release the memory // after the launch kernel of the stream is executed (reapply pinned memory -- GitLab