From cef884e518396335776749c89c318351d64be5d8 Mon Sep 17 00:00:00 2001
From: Pei Yang
Date: Fri, 20 Sep 2019 16:48:54 +0800
Subject: [PATCH] refine concat cuda kernel, test=develop (#2081)

---
 lite/kernels/cuda/concat_compute.cu         | 271 ++++--------------
 lite/kernels/cuda/concat_compute_test.cc    |  13 +-
 .../cuda/elementwise_add_compute_test.cc    |   1 -
 lite/kernels/cuda/nearest_interp_compute.cu |   4 +-
 4 files changed, 58 insertions(+), 231 deletions(-)

diff --git a/lite/kernels/cuda/concat_compute.cu b/lite/kernels/cuda/concat_compute.cu
index 10a9414935..a50b9d270f 100644
--- a/lite/kernels/cuda/concat_compute.cu
+++ b/lite/kernels/cuda/concat_compute.cu
@@ -21,134 +21,25 @@ namespace kernels {
 namespace cuda {
 using Tensor = lite::Tensor;
 
-template <typename T>
-__global__ void ConcatKernel(const T** inputs,
-                             const int* input_cols,
-                             int col_size,
-                             const int output_rows,
-                             const int output_cols,
-                             T* output) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  int curr_segment = 0;
-  int curr_offset = input_cols[0];
-  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
-    int curr_col_offset = input_cols[curr_segment + 1];
-    while (curr_col_offset <= tid_x) {
-      curr_offset = curr_col_offset;
-      ++curr_segment;
-      curr_col_offset = input_cols[curr_segment + 1];
-    }
-
-    int local_col = tid_x - curr_offset;
-    int segment_width = curr_col_offset - curr_offset;
-
-    const T* input_ptr = inputs[curr_segment];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
-      output[tid_y * output_cols + tid_x] =
-          input_ptr[tid_y * segment_width + local_col];
-  }
-}
-
-template <typename T>
-__device__ void ConcatKernelDetail(const T** inputs_data,
-                                   const int fixed_in_col,
-                                   const int out_rows,
-                                   const int out_cols,
-                                   T* output_data) {
-  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
-  for (; tid_x < out_cols; tid_x += blockDim.x * gridDim.x) {
-    int split = tid_x * 1.0 / fixed_in_col;
-    int in_offset = tid_x - split * fixed_in_col;
-    const T* input_ptr = inputs_data[split];
-    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
-    for (; tid_y < out_rows; tid_y += blockDim.y * gridDim.y) {
-      output_data[tid_y * out_cols + tid_x] =
-          input_ptr[tid_y * fixed_in_col + in_offset];
-    }
-  }
-  // for (int i = 0; i < 4; i++){
-  //   printf("input[0][%d] = %.1f\n", i, inputs_data[0][i]);
-  //   printf("output[%d] = %.1f\n", i, output_data[i]);
-  // }
-}
-
-template <typename T>
-__global__ void ConcatKernel(const T* input_addr0,
-                             const T* input_addr1,
-                             const int fixed_in_col,
-                             const int out_rows,
-                             const int out_cols,
-                             T* output_data) {
-  const T* inputs_data[2];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel(const T* input_addr0,
-                             const T* input_addr1,
-                             const T* input_addr2,
-                             const int fixed_in_col,
-                             const int out_rows,
-                             const int out_cols,
-                             T* output_data) {
-  const T* inputs_data[3];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  inputs_data[2] = input_addr2;
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel(const T* input_addr0,
-                             const T* input_addr1,
-                             const T* input_addr2,
-                             const T* input_addr3,
-                             const int fixed_in_col,
-                             const int out_rows,
-                             const int out_cols,
-                             T* output_data) {
-  const T* inputs_data[4];
-  inputs_data[0] = input_addr0;
-  inputs_data[1] = input_addr1;
-  inputs_data[2] = input_addr2;
-  inputs_data[3] = input_addr3;
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-template <typename T>
-__global__ void ConcatKernel(const T** inputs_data,
-                             const int in_num,
-                             const int fixed_in_col,
-                             const int out_rows,
-                             const int out_cols,
-                             T* output_data) {
-  ConcatKernelDetail(
-      inputs_data, fixed_in_col, out_rows, out_cols, output_data);
-}
-
-static inline void GetBlockDims(const CUDAContext& context,
-                                int num_rows,
-                                int num_cols,
-                                dim3* block_dims,
-                                dim3* grid_dims) {
-  // Set the thread block and grid according to CurrentDeviceId
-  const int kThreadsPerBlock = 1024;
-  int block_cols = kThreadsPerBlock;
-  if (num_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
-    block_cols = ((num_cols + 31) >> 5) << 5;
+template <typename Dtype>
+__global__ void Concat(const int num,
+                       const Dtype* in_data,
+                       const int num_concats,
+                       const int concat_size,
+                       const int top_concat_axis,
+                       const int bottom_concat_axis,
+                       const int offset_concat_axis,
+                       Dtype* out_data) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  if (index < num) {
+    const int total_concat_size = concat_size * bottom_concat_axis;
+    const int concat_num = index / total_concat_size;
+    const int concat_index = index % total_concat_size;
+    const int top_index =
+        concat_index +
+        (concat_num * top_concat_axis + offset_concat_axis) * concat_size;
+    out_data[top_index] = in_data[index];
   }
-  int block_rows = kThreadsPerBlock / block_cols;
-  *block_dims = dim3(block_cols, block_rows, 1);
-
-  int grid_cols = (num_cols + block_cols - 1) / block_cols;
-  int grid_rows = std::max(num_rows / block_rows, 1);
-  *grid_dims = dim3(grid_cols, grid_rows, 1);
 }
 
 void ConcatCompute::Run() {
@@ -158,105 +49,40 @@ void ConcatCompute::Run() {
 
   std::vector<Tensor*> input = param.x;
   Tensor* output = param.output;
+  auto* output_data = output->mutable_data<float>(TARGET(kCUDA));
   int axis = param.axis;
-
-  int in_num = input.size();
-  int in_row = 1;
-  auto dim_0 = input[0]->dims();
-  for (int i = 0; i < axis; ++i) {
-    in_row *= dim_0[i];
+  int inner_size = 1;
+  int outer_size = 1;
+  auto input_dims = input[0]->dims();
+  for (int i = 0; i < axis; i++) {
+    outer_size *= input_dims[i];
   }
-  int in_col = input[0]->numel() / in_row;
-  int out_row = in_row, out_col = 0;
-  std::vector<const float*> inputs_data(in_num);
-  std::vector<int> inputs_col(in_num + 1);
-  inputs_col[0] = 0;
-  bool has_same_shape = true;
-  for (int i = 0; i < in_num; ++i) {
-    int t_cols = input[i]->numel() / in_row;
-    if (has_same_shape) {
-      if (t_cols != in_col) has_same_shape = false;
-    }
-    out_col += t_cols;
-    inputs_col[i + 1] = out_col;
-    inputs_data[i] = input[i]->data<float>();
-  }
-  dim3 block_dims;
-  dim3 grid_dims;
-  GetBlockDims(ctx, out_row, out_col, &block_dims, &grid_dims);
-  const float** dev_ins_data = nullptr;
-  if (!has_same_shape || in_num < 2 || in_num > 4) {
-    float* tmp_dev_ins_data = nullptr;
-    CHECK(cudaSuccess ==
-          cudaMalloc(&tmp_dev_ins_data, inputs_data.size() * sizeof(float*)));
-    CHECK(cudaSuccess == cudaMemcpy(tmp_dev_ins_data,
-                                    static_cast<void*>(inputs_data.data()),
-                                    inputs_data.size() * sizeof(float*),
-                                    cudaMemcpyHostToDevice));
-    dev_ins_data = reinterpret_cast<const float**>(tmp_dev_ins_data);
+  for (int i = axis + 1; i < input_dims.size(); i++) {
+    inner_size *= input_dims[i];
   }
-  if (has_same_shape) {
-    if (in_num == 2) {
-      ConcatKernel<<<grid_dims, block_dims, 0, stream>>>(
-          inputs_data[0],
-          inputs_data[1],
-          in_col,
-          out_row,
-          out_col,
-          output->mutable_data<float>());
-    } else if (in_num == 3) {
-      ConcatKernel<<<grid_dims, block_dims, 0, stream>>>(
-          inputs_data[0],
-          inputs_data[1],
-          inputs_data[2],
-          in_col,
-          out_row,
-          out_col,
-          output->mutable_data<float>());
-    } else if (in_num == 4) {
-      ConcatKernel<<<grid_dims, block_dims, 0, stream>>>(
-          inputs_data[0],
-          inputs_data[1],
-          inputs_data[2],
-          inputs_data[3],
-          in_col,
-          out_row,
-          out_col,
-          output->mutable_data<float>());
-    } else {
-      ConcatKernel<<<grid_dims, block_dims, 0, stream>>>(
-          dev_ins_data,
-          in_num,
-          in_col,
-          out_row,
-          out_col,
-          output->mutable_data<float>());
-      cudaFree(dev_ins_data);
-    }
-  } else {
-    int* tmp_dev_ins_col_data = nullptr;
-    CHECK(cudaSuccess ==
-          cudaMalloc(&tmp_dev_ins_col_data, inputs_col.size() * sizeof(int)));
-    CHECK(cudaSuccess == cudaMemcpy(tmp_dev_ins_col_data,
-                                    static_cast<void*>(inputs_col.data()),
-                                    inputs_col.size() * sizeof(int),
-                                    cudaMemcpyHostToDevice));
-    int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data);
-    ConcatKernel<<<grid_dims, block_dims, 0, stream>>>(
-        dev_ins_data,
-        dev_ins_col_data,
-        static_cast<int>(inputs_col.size()),
-        out_row,
-        out_col,
-        output->mutable_data<float>());
-    cudaFree(dev_ins_data);
-    cudaFree(dev_ins_col_data);
+  int all_concat_axis = param.output->dims()[axis];
+  int in_num = input.size();
+  int offset_concat_axis = 0;
+
+  for (int i = 0; i < in_num; i++) {
+    auto* input_data = input[i]->data<float>();
+    int input_concat_axis = input[i]->dims()[axis];
+    int input_concat_size = input_concat_axis * inner_size;
+    int num = input_concat_size * outer_size;
+    int threads = 1024;
+    int blocks = (num + threads - 1) / threads;
+    Concat<<<blocks, threads, 0, stream>>>(num,
+                                           input_data,
+                                           outer_size,
+                                           inner_size,
+                                           all_concat_axis,
+                                           input_concat_axis,
+                                           offset_concat_axis,
+                                           output_data);
+    offset_concat_axis += input_concat_axis;
   }
-
-  cudaError_t error = cudaGetLastError();
-  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
 }
 
 }  // namespace cuda
@@ -270,7 +96,6 @@ REGISTER_LITE_KERNEL(concat,
                      kNCHW,
                      paddle::lite::kernels::cuda::ConcatCompute,
                      def)
-    .BindInput("x", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindInput("axis", {LiteType::GetTensorTy(TARGET(kCUDA))})
-    .BindOutput("output", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
+    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
     .Finalize();
diff --git a/lite/kernels/cuda/concat_compute_test.cc b/lite/kernels/cuda/concat_compute_test.cc
index 8dc097be48..254c1326f3 100644
--- a/lite/kernels/cuda/concat_compute_test.cc
+++ b/lite/kernels/cuda/concat_compute_test.cc
@@ -126,10 +126,10 @@ TEST(concat, compute_input_multi) {
   lite::Tensor tensorC_ref;
   lite::Tensor tensorD_ref;
 
-  DDimLite ddimA({1, 3, 1, 2});
-  DDimLite ddimB({1, 4, 1, 2});
-  DDimLite ddimC({1, 5, 1, 2});
-  DDimLite ddimD({1, 6, 1, 2});
+  DDimLite ddimA({1, 3, 38, 38});
+  DDimLite ddimB({1, 4, 38, 38});
+  DDimLite ddimC({1, 5, 38, 38});
+  DDimLite ddimD({1, 6, 38, 38});
 
   tensorA.Resize(ddimA);
   tensorB.Resize(ddimB);
@@ -144,6 +144,9 @@ TEST(concat, compute_input_multi) {
   tensorC_ref.Resize(ddimC);
   tensorD_ref.Resize(ddimD);
 
+  out.Resize({1, 18, 38, 38});
+  out_cpu.Resize({1, 18, 38, 38});
+  out_ref.Resize({1, 18, 38, 38});
   auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
   auto* out_cpu_data = out_cpu.mutable_data<float>();
   auto* out_ref_data = out_ref.mutable_data<float>();
@@ -215,7 +218,7 @@ TEST(concat, compute_input_multi) {
   concat_compute_ref(param_ref);
   LOG(INFO) << "concat_compute_ref end";
 
-  for (int i = 0; i < out.numel(); i++) {
+  for (int i = 0; i < out_ref.numel(); i++) {
     EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-5);
   }
 }
diff --git a/lite/kernels/cuda/elementwise_add_compute_test.cc b/lite/kernels/cuda/elementwise_add_compute_test.cc
index ea9998c8d3..f34f75961f 100644
--- a/lite/kernels/cuda/elementwise_add_compute_test.cc
+++ b/lite/kernels/cuda/elementwise_add_compute_test.cc
@@ -27,7 +27,6 @@ using Tensor = lite::Tensor;
 static void ElementwiseAddRef(float* x, float* y, float* out, int num) {
   for (int i = 0; i < num; ++i) {
     out[i] = x[i] + y[i];
-    // LOG(INFO) << x[i] << " + " << y[i] << " = " << out[i];
   }
 }
 
diff --git a/lite/kernels/cuda/nearest_interp_compute.cu b/lite/kernels/cuda/nearest_interp_compute.cu
index 8edeacfe5a..152872a8d2 100644
--- a/lite/kernels/cuda/nearest_interp_compute.cu
+++ b/lite/kernels/cuda/nearest_interp_compute.cu
@@ -120,9 +120,9 @@ void NearestInterpCompute::Run() {
   int in_chw = c * in_hw;
   int out_chw = c * out_hw;
 
-  int pixelNum = n * out_chw;
+  int pixel_num = n * out_chw;
   int threads = 512;
-  int blocks = (pixelNum + threads - 1) / threads;
+  int blocks = (pixel_num + threads - 1) / threads;
   blocks = blocks > 8 ? 8 : blocks;
 
   KeNearestNeighborInterp<<<blocks, threads, 0, stream>>>(input_data,
-- 
GitLab
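Note (appended for review, not part of the patch): the new kernel views each input as an (outer_size x bottom_concat_axis x concat_size) block and copies element `index` to an output position shifted along the concat axis. Below is a minimal host-side C++ sketch of that same index arithmetic under toy shapes; the names ConcatRef, axis_sizes, and the harness in main are hypothetical illustrations (the kernel's unused num_concats parameter is omitted here), not code from this patch.

// CPU analogue of the patch's Concat kernel: one loop iteration plays the
// role of one CUDA thread, copying a single element of one input tensor
// into its slice of the output along the concatenation axis.
#include <cassert>
#include <cstdio>
#include <vector>

void ConcatRef(int num,                 // element count of this input
               const float* in_data,
               int concat_size,         // product of dims after the axis
               int top_concat_axis,     // output extent along the axis
               int bottom_concat_axis,  // this input's extent along the axis
               int offset_concat_axis,  // start offset of this input in output
               float* out_data) {
  for (int index = 0; index < num; ++index) {
    const int total_concat_size = concat_size * bottom_concat_axis;
    const int concat_num = index / total_concat_size;    // outer (pre-axis) index
    const int concat_index = index % total_concat_size;  // position inside slice
    const int top_index =
        concat_index +
        (concat_num * top_concat_axis + offset_concat_axis) * concat_size;
    out_data[top_index] = in_data[index];
  }
}

int main() {
  // Hypothetical shapes: concat (1, 2, 2, 2) and (1, 3, 2, 2) along axis 1.
  const int outer_size = 1;      // product of dims before the axis
  const int inner_size = 2 * 2;  // product of dims after the axis
  const int axis_sizes[2] = {2, 3};
  const int top_axis = axis_sizes[0] + axis_sizes[1];

  std::vector<float> a(outer_size * axis_sizes[0] * inner_size, 1.0f);
  std::vector<float> b(outer_size * axis_sizes[1] * inner_size, 2.0f);
  std::vector<float> out(outer_size * top_axis * inner_size, 0.0f);

  const float* inputs[2] = {a.data(), b.data()};
  int offset = 0;
  for (int i = 0; i < 2; ++i) {  // mirrors the per-input loop in Run()
    int num = outer_size * axis_sizes[i] * inner_size;
    ConcatRef(num, inputs[i], inner_size, top_axis, axis_sizes[i], offset,
              out.data());
    offset += axis_sizes[i];
  }

  // Channels 0-1 must come from input a, channels 2-4 from input b.
  for (int c = 0; c < top_axis; ++c) {
    for (int j = 0; j < inner_size; ++j) {
      assert(out[c * inner_size + j] == (c < axis_sizes[0] ? 1.0f : 2.0f));
    }
  }
  std::printf("concat reference check passed\n");
  return 0;
}

Built with a plain C++ compiler, the harness asserts that the first two output channels hold the first input and the remaining three hold the second, which is the layout the CUDA kernel produces with one thread per input element.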