From 0511794e69726f0b1f2090893bea1208ad3f3d26 Mon Sep 17 00:00:00 2001
From: Wilber
Date: Wed, 25 Dec 2019 11:25:38 +0800
Subject: [PATCH] optimize softmax cuda kernel test=develop (#2660)

optimize softmax cuda kernel
---
 lite/kernels/cuda/softmax_compute.cu | 28 +++++++++++++---------------
 lite/kernels/cuda/softmax_compute.h  |  8 +++++---
 2 files changed, 18 insertions(+), 18 deletions(-)

diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu
index 157c6ae889..431bd6eb56 100644
--- a/lite/kernels/cuda/softmax_compute.cu
+++ b/lite/kernels/cuda/softmax_compute.cu
@@ -156,8 +156,8 @@ void SoftmaxCompute::PrepareForRun() {
   cudaGetDevice(&device_id);
   cudaDeviceProp deviceProp;
   cudaGetDeviceProperties(&deviceProp, device_id);
-  sharedmem_size = deviceProp.sharedMemPerBlock;
-  max_dimsize = sharedmem_size / sizeof(float) / CUDA_NUM_THREADS;
+  sharedmem_size_ = deviceProp.sharedMemPerBlock;
+  max_dimsize_ = sharedmem_size_ / sizeof(float) / CUDA_NUM_THREADS;
 }
 
 void SoftmaxCompute::Run() {
@@ -174,29 +174,27 @@ void SoftmaxCompute::Run() {
   int outer_num = x_dims.Slice(0, axis).production();
   int inner_num = x_dims.Slice(axis + 1, x_rank).production();
   int total_threads = inner_num * outer_num;
-  int axis_size = x_dims[axis];
+  axis_size_ = x_dims[axis];
 
   const int threads = CUDA_NUM_THREADS;
   const int blocks = (total_threads + threads - 1) / threads;
   auto input_data = param.x->data<float>();
   auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
-  if (axis_size <= max_dimsize) {
-    int use_sharemem_size = axis_size * threads * sizeof(float);
+  if (axis_size_ <= max_dimsize_) {
+    int use_sharemem_size = axis_size_ * threads * sizeof(float);
     sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
         total_threads,
         input_data,
         output_data,
         inner_num,
         outer_num,
-        axis_size);
+        axis_size_);
   } else {
     //! re_alloc device memory
-    Tensor tmax_data;
-    Tensor tsum_data;
-    tmax_data.Resize({1, 1, 1, outer_num * inner_num});
-    tsum_data.Resize({1, 1, 1, outer_num * inner_num});
-    auto max_data = tmax_data.mutable_data<float>(TARGET(kCUDA));
-    auto sum_data = tsum_data.mutable_data<float>(TARGET(kCUDA));
+    tmax_data_.Resize({1, 1, 1, outer_num * inner_num});
+    tsum_data_.Resize({1, 1, 1, outer_num * inner_num});
+    auto max_data = tmax_data_.mutable_data<float>(TARGET(kCUDA));
+    auto sum_data = tsum_data_.mutable_data<float>(TARGET(kCUDA));
     //! firstly, get maximum data
     float min_data = std::numeric_limits<float>::lowest();
     softmax_max_kernel<<<blocks, threads, 0, stream>>>(total_threads,
@@ -205,7 +203,7 @@ void SoftmaxCompute::Run() {
                                                        min_data,
                                                        inner_num,
                                                        outer_num,
-                                                       axis_size);
+                                                       axis_size_);
     //! then, compute exp and sum data
     softmax_sub_exp_sum_kernel<<<blocks, threads, 0, stream>>>(
         total_threads,
@@ -215,10 +213,10 @@ void SoftmaxCompute::Run() {
         sum_data,
         inner_num,
         outer_num,
-        axis_size);
+        axis_size_);
     //! last, compute divided output
     softmax_divid_output_kernel<<<blocks, threads, 0, stream>>>(
-        total_threads, output_data, sum_data, inner_num, outer_num, axis_size);
+        total_threads, output_data, sum_data, inner_num, outer_num, axis_size_);
   }
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
diff --git a/lite/kernels/cuda/softmax_compute.h b/lite/kernels/cuda/softmax_compute.h
index 72d43a8eff..e563b36178 100644
--- a/lite/kernels/cuda/softmax_compute.h
+++ b/lite/kernels/cuda/softmax_compute.h
@@ -30,9 +30,11 @@ class SoftmaxCompute
   virtual ~SoftmaxCompute() = default;
 
  private:
-  size_t sharedmem_size;
-  int num_threads;
-  int max_dimsize;
+  lite::Tensor tmax_data_;
+  lite::Tensor tsum_data_;
+  size_t sharedmem_size_;
+  int max_dimsize_;
+  int axis_size_;
 };
 
 }  // namespace cuda
--
GitLab
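
For readers coming to this patch without the surrounding Paddle-Lite sources, the sketch below illustrates the pattern applied on the fallback path (axis_size_ > max_dimsize_): the per-row max and sum scratch buffers live on the kernel object and are reused across Run() calls, growing only when a larger problem arrives, instead of being allocated and freed on every invocation. All names here (SoftmaxScratch, row_max_kernel, sub_exp_sum_kernel, divide_kernel) are illustrative stand-ins rather than Paddle-Lite APIs, and the indexing is simplified to one contiguous row per thread, i.e. the inner_num == 1 case of the real kernels.

#include <cuda_runtime.h>

#include <limits>

// One thread per row; each thread walks its whole row, mirroring the
// fallback path taken when the row no longer fits the shared-memory budget.
__global__ void row_max_kernel(const float* in, float* row_max,
                               int rows, int cols, float lowest) {
  int r = blockIdx.x * blockDim.x + threadIdx.x;
  if (r >= rows) return;
  float m = lowest;
  for (int c = 0; c < cols; ++c) m = fmaxf(m, in[r * cols + c]);
  row_max[r] = m;
}

// Subtract the row max, exponentiate, and accumulate the row sum.
__global__ void sub_exp_sum_kernel(const float* in, float* out,
                                   const float* row_max, float* row_sum,
                                   int rows, int cols) {
  int r = blockIdx.x * blockDim.x + threadIdx.x;
  if (r >= rows) return;
  float s = 0.f;
  for (int c = 0; c < cols; ++c) {
    float e = expf(in[r * cols + c] - row_max[r]);
    out[r * cols + c] = e;
    s += e;
  }
  row_sum[r] = s;
}

// Normalize each row by its sum.
__global__ void divide_kernel(float* out, const float* row_sum,
                              int rows, int cols) {
  int r = blockIdx.x * blockDim.x + threadIdx.x;
  if (r >= rows) return;
  for (int c = 0; c < cols; ++c) out[r * cols + c] /= row_sum[r];
}

// Scratch buffers are owned by the object, so steady-state calls perform
// no cudaMalloc/cudaFree; this is the effect of promoting tmax_data_ and
// tsum_data_ from locals to class members in the patch above.
class SoftmaxScratch {
 public:
  ~SoftmaxScratch() {
    cudaFree(max_);  // cudaFree(nullptr) is a no-op
    cudaFree(sum_);
  }

  void Run(const float* in, float* out, int rows, int cols,
           cudaStream_t stream) {
    if (rows > capacity_) {  // grow-only reallocation
      cudaFree(max_);
      cudaFree(sum_);
      cudaMalloc(&max_, rows * sizeof(float));
      cudaMalloc(&sum_, rows * sizeof(float));
      capacity_ = rows;
    }
    const int threads = 256;
    const int blocks = (rows + threads - 1) / threads;
    float lowest = std::numeric_limits<float>::lowest();
    row_max_kernel<<<blocks, threads, 0, stream>>>(in, max_, rows, cols,
                                                   lowest);
    sub_exp_sum_kernel<<<blocks, threads, 0, stream>>>(in, out, max_, sum_,
                                                       rows, cols);
    divide_kernel<<<blocks, threads, 0, stream>>>(out, sum_, rows, cols);
  }

 private:
  float* max_ = nullptr;
  float* sum_ = nullptr;
  int capacity_ = 0;
};

The rest of the patch is mechanical: data members take trailing-underscore names (sharedmem_size_, max_dimsize_, axis_size_) to match the member-naming convention used elsewhere in the file, and the num_threads member, unused in the code shown here, is dropped.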