From 922ace19a45f30075618f71428523e7a2d5898d6 Mon Sep 17 00:00:00 2001 From: Wilber Date: Sun, 19 Jan 2020 19:02:09 +0800 Subject: [PATCH] fix bug for crmm model test=develop (#2786) - modify aligned_matmul kernel for dynamically malloc memory - fix top_k_avg_pooling kernel to support data whose size is more than cuda shared memory. --- .../cuda/search_aligned_mat_mul_compute.h | 25 +-- .../cuda/sequence_topk_avg_pooling_compute.cu | 153 ++++++++++++++++-- .../cuda/sequence_topk_avg_pooling_compute.h | 3 + 3 files changed, 150 insertions(+), 31 deletions(-) diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute.h b/lite/kernels/cuda/search_aligned_mat_mul_compute.h index b1c4552d9c..8304b0f2b4 100644 --- a/lite/kernels/cuda/search_aligned_mat_mul_compute.h +++ b/lite/kernels/cuda/search_aligned_mat_mul_compute.h @@ -31,21 +31,12 @@ class SearchAlignedMatMulCompute using param_t = operators::MatMulParam; void PrepareForRun() override { - auto& param = this->Param(); - CHECK(ctx_) << "running context should be set first"; - auto& cuda_ctx = ctx_->template As(); - bool x_transpose = param.transpose_X; - bool y_transpose = param.transpose_Y; - int seq_num = param.X->lod()[0].size() - 1; batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm); - CHECK( - batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx)); - A_ = static_cast(malloc(3 * seq_num * sizeof(float*))); - CHECK(A_); } void Run() override { auto& param = this->Param(); + auto& cuda_ctx = ctx_->template As(); auto x = param.X; auto y = param.Y; auto out = param.Out; @@ -76,25 +67,25 @@ class SearchAlignedMatMulCompute auto x_stride = x_batch_size * x_inner_size; auto y_stride = y_batch_size * y_inner_size; auto out_stride = M * N; - for (int seq = 0; seq < seq_num; seq++) { + + float* A_[seq_num * 3]; + for (int seq = 0; seq < seq_num; ++seq) { A_[seq] = const_cast(x_data) + seq * x_stride; A_[seq + seq_num] = const_cast(y_data) + seq * y_stride; A_[seq + seq_num * 2] = out_data + seq * out_stride; } + + CHECK( + batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx)); batched_gemm_impl_->run( alpha, 0.0f, const_cast(A_), M, N, K, seq_num); } - ~SearchAlignedMatMulCompute() { - if (A_ != nullptr) { - free(A_); - } - } + ~SearchAlignedMatMulCompute() { batched_gemm_impl_.reset(); } private: std::unique_ptr> batched_gemm_impl_; - float** A_{nullptr}; }; } // namespace cuda diff --git a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu index 4794644c6d..25ea6b2ea9 100644 --- a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu +++ b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.cu @@ -93,6 +93,111 @@ __global__ void topk_avg_pooling_kernel_by_row_improve( } } +template +__global__ void topk_avg_pooling_kernel_for_big_data( + Dtype *output_data, + const Dtype *input_data, + const int *gpu_input_offset_l, + const int *gpu_input_offset_r, + const int row_max, + const int col_max, + const int topk_size, + const int *topks, + const int feat_map_num, + const int actual_row_in_shared_mem) { + int row = gpu_input_offset_l[blockIdx.x + 1] - + gpu_input_offset_l[blockIdx.x]; // 75 + int col = gpu_input_offset_r[blockIdx.x + 1] - + gpu_input_offset_r[blockIdx.x]; // 300 + + int max_k = topks[topk_size - 1]; + max_k = max_k < col ? max_k : col; + + extern __shared__ Dtype smem[]; // H1*W or H2*W ... + + int filled_z = row / actual_row_in_shared_mem; + int remain_row = row - filled_z * actual_row_in_shared_mem; + + if (blockIdx.z > filled_z || (blockIdx.z == filled_z && remain_row == 0)) { + return; + } + + const Dtype *fm_row_in_data = input_data + + blockIdx.x * row_max * feat_map_num * col_max + + blockIdx.y * row_max * col_max + + blockIdx.z * actual_row_in_shared_mem * col_max; + if (blockIdx.z == filled_z) { + for (int i = threadIdx.x; i < remain_row * col_max; i += blockDim.x) { + smem[i] = fm_row_in_data[i]; + } + } else { + for (int i = threadIdx.x; i < actual_row_in_shared_mem * col_max; + i += blockDim.x) { + smem[i] = fm_row_in_data[i]; + } + } + __syncthreads(); + + int cur_row; + if (blockIdx.z == filled_z) { + cur_row = remain_row; + } else { + cur_row = actual_row_in_shared_mem; + } + + for (int idx = threadIdx.x; idx < cur_row; idx += blockDim.x) { + Dtype *fm_row_out_data = output_data + + (gpu_input_offset_l[blockIdx.x] + + blockIdx.z * actual_row_in_shared_mem + idx) * + feat_map_num * topk_size + + blockIdx.y * topk_size; + + Dtype *smem_start_col = smem + idx * col_max; + + int counter = max_k; // topk_size; + Dtype last_max_val = -20000.0; + while (counter) { + Dtype max_val = -10000.0; + int max_pos = 0; // -1; + int m = 0; + for (; m < col; m++) { + Dtype cur_data = smem_start_col[m]; + if (cur_data > max_val) { + max_val = cur_data; + max_pos = m; + last_max_val = max_val; + } + } + if (max_val < -9999.0) { // == -10000.0 + max_val = last_max_val; + } + smem_start_col[max_pos] = -10000000.0; + + int i = max_k - counter; + for (int c = 0; c < topk_size; c++) { + if (i <= topks[c] - 1) { + fm_row_out_data[c] += max_val; + } + } + counter--; + } + __syncthreads(); + // compute avg + for (int i = 0; i < topk_size; i++) { + fm_row_out_data[i] = fm_row_out_data[i] / topks[i]; + } + } +} + +template +void SequenceTopkAvgPoolingCompute::PrepareForRun() { + int device_id; + cudaGetDevice(&device_id); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, device_id); + _shared_mem_size = deviceProp.sharedMemPerBlock; +} + template void SequenceTopkAvgPoolingCompute::Run() { auto ¶m = this->Param(); @@ -156,20 +261,40 @@ void SequenceTopkAvgPoolingCompute::Run() { int feat_map_size = height * width; - dim3 blocks(num, channel); - dim3 threads(32, 1); - - topk_avg_pooling_kernel_by_row_improve< - T><<>>( - out_data, - in_data, - height_offset, - width_offset, - height, - width, - param.topks.size(), - _top_ks.data(), - param.channel_num); + if (feat_map_size * sizeof(T) <= _shared_mem_size) { + dim3 blocks(num, channel); + dim3 threads(32, 1); + + topk_avg_pooling_kernel_by_row_improve< + T><<>>( + out_data, + in_data, + height_offset, + width_offset, + height, + width, + param.topks.size(), + _top_ks.data(), + param.channel_num); + } else { + int actual_row = _shared_mem_size / width / sizeof(T); + int num_z = (height + actual_row - 1) / actual_row; + dim3 blocks(num, channel, num_z); + dim3 threads(32, 1); + + topk_avg_pooling_kernel_for_big_data< + T><<>>( + out_data, + in_data, + height_offset, + width_offset, + height, + width, + param.topks.size(), + _top_ks.data(), + param.channel_num, + actual_row); + } } } // namespace cuda diff --git a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.h b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.h index 321ec9cfce..6f4be12f0f 100644 --- a/lite/kernels/cuda/sequence_topk_avg_pooling_compute.h +++ b/lite/kernels/cuda/sequence_topk_avg_pooling_compute.h @@ -29,12 +29,15 @@ class SequenceTopkAvgPoolingCompute void Run() override; + void PrepareForRun() override; + virtual ~SequenceTopkAvgPoolingCompute() = default; protected: lite::Tensor _height_offset; lite::Tensor _width_offset; lite::Tensor _top_ks; + int _shared_mem_size; }; } // namespace cuda -- GitLab