Unverified commit 8f593443, authored by Wilber, committed by GitHub

optimize softmax cuda kernel test=develop (#2660)

optimize softmax cuda kernel
Parent 00fee283
@@ -156,8 +156,8 @@ void SoftmaxCompute::PrepareForRun() {
   cudaGetDevice(&device_id);
   cudaDeviceProp deviceProp;
   cudaGetDeviceProperties(&deviceProp, device_id);
-  sharedmem_size = deviceProp.sharedMemPerBlock;
-  max_dimsize = sharedmem_size / sizeof(float) / CUDA_NUM_THREADS;
+  sharedmem_size_ = deviceProp.sharedMemPerBlock;
+  max_dimsize_ = sharedmem_size_ / sizeof(float) / CUDA_NUM_THREADS;
 }

 void SoftmaxCompute::Run() {
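PrepareForRun() caches the device's per-block shared-memory limit in the renamed sharedmem_size_ member and derives max_dimsize_ from it: the longest softmax axis that still fits when each of the CUDA_NUM_THREADS threads in a block buffers one full axis of floats in shared memory. A standalone sketch of the same budget arithmetic (the 48 KB limit and 512-thread block size are illustrative assumptions, not values taken from this commit):

#include <cstdio>

// Illustrative stand-ins; the real values come from cudaGetDeviceProperties()
// and the framework's CUDA_NUM_THREADS constant.
constexpr size_t kSharedMemPerBlock = 48 * 1024;  // assumed 48 KB per block
constexpr size_t kCudaNumThreads = 512;           // assumed block size

int main() {
  // Each thread buffers one softmax axis of floats in shared memory, so the
  // longest axis that fits is bytes / sizeof(float) / threads-per-block.
  size_t max_dimsize = kSharedMemPerBlock / sizeof(float) / kCudaNumThreads;
  std::printf("max_dimsize = %zu\n", max_dimsize);  // 49152 / 4 / 512 = 24
  return 0;
}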
@@ -174,29 +174,27 @@ void SoftmaxCompute::Run() {
   int outer_num = x_dims.Slice(0, axis).production();
   int inner_num = x_dims.Slice(axis + 1, x_rank).production();
   int total_threads = inner_num * outer_num;
-  int axis_size = x_dims[axis];
+  axis_size_ = x_dims[axis];

   const int threads = CUDA_NUM_THREADS;
   const int blocks = (total_threads + threads - 1) / threads;
   auto input_data = param.x->data<float>();
   auto output_data = param.output->mutable_data<float>(TARGET(kCUDA));
-  if (axis_size <= max_dimsize) {
-    int use_sharemem_size = axis_size * threads * sizeof(float);
+  if (axis_size_ <= max_dimsize_) {
+    int use_sharemem_size = axis_size_ * threads * sizeof(float);
     sharemem_softmax_kernel<<<blocks, threads, use_sharemem_size, stream>>>(
         total_threads,
         input_data,
         output_data,
         inner_num,
         outer_num,
-        axis_size);
+        axis_size_);
   } else {
     //! re_alloc device memory
-    Tensor tmax_data;
-    Tensor tsum_data;
-    tmax_data.Resize({1, 1, 1, outer_num * inner_num});
-    tsum_data.Resize({1, 1, 1, outer_num * inner_num});
-    auto max_data = tmax_data.mutable_data<float>(TARGET(kCUDA));
-    auto sum_data = tsum_data.mutable_data<float>(TARGET(kCUDA));
+    tmax_data_.Resize({1, 1, 1, outer_num * inner_num});
+    tsum_data_.Resize({1, 1, 1, outer_num * inner_num});
+    auto max_data = tmax_data_.mutable_data<float>(TARGET(kCUDA));
+    auto sum_data = tsum_data_.mutable_data<float>(TARGET(kCUDA));
     //! firstly, get maximum data
     float min_data = std::numeric_limits<float>::lowest();
     softmax_max_kernel<float><<<blocks, threads, 0, stream>>>(total_threads,
@@ -205,7 +203,7 @@ void SoftmaxCompute::Run() {
                                                               min_data,
                                                               inner_num,
                                                               outer_num,
-                                                              axis_size);
+                                                              axis_size_);
     //! then, compute exp and sum data
     softmax_sub_exp_sum_kernel<float><<<blocks, threads, 0, stream>>>(
         total_threads,
@@ -215,10 +213,10 @@ void SoftmaxCompute::Run() {
         sum_data,
         inner_num,
         outer_num,
-        axis_size);
+        axis_size_);
     //! last, compute divided output
     softmax_divid_output_kernel<float><<<blocks, threads, 0, stream>>>(
-        total_threads, output_data, sum_data, inner_num, outer_num, axis_size);
+        total_threads, output_data, sum_data, inner_num, outer_num, axis_size_);
   }
   cudaError_t error = cudaGetLastError();
   if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error);
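Run() dispatches on the cached budget: when axis_size_ fits, a single fused kernel keeps each thread's whole axis in shared memory; otherwise it falls back to three passes (max, subtract-exp-sum, divide) that stage per-(outer, inner) maxima and sums in the scratch tensors. A host-side reference of that fallback, offered as a sketch only — the commit's softmax_max_kernel, softmax_sub_exp_sum_kernel, and softmax_divid_output_kernel distribute these same loops across one CUDA thread per (outer, inner) pair:

#include <algorithm>
#include <cmath>
#include <limits>

// Reference semantics of the fallback path for a tensor laid out as
// [outer_num, axis_size, inner_num], flattened row-major.
void softmax_reference(const float* in, float* out,
                       int outer_num, int axis_size, int inner_num) {
  for (int o = 0; o < outer_num; ++o) {
    for (int i = 0; i < inner_num; ++i) {
      // Pass 1: maximum along the softmax axis (numerical stability).
      float max_v = std::numeric_limits<float>::lowest();
      for (int a = 0; a < axis_size; ++a)
        max_v = std::max(max_v, in[(o * axis_size + a) * inner_num + i]);
      // Pass 2: subtract the max, exponentiate, accumulate the sum.
      float sum = 0.f;
      for (int a = 0; a < axis_size; ++a) {
        int idx = (o * axis_size + a) * inner_num + i;
        out[idx] = std::exp(in[idx] - max_v);
        sum += out[idx];
      }
      // Pass 3: normalize by the sum.
      for (int a = 0; a < axis_size; ++a)
        out[(o * axis_size + a) * inner_num + i] /= sum;
    }
  }
}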
...
@@ -30,9 +30,11 @@ class SoftmaxCompute
   virtual ~SoftmaxCompute() = default;

  private:
-  size_t sharedmem_size;
-  int num_threads;
-  int max_dimsize;
+  lite::Tensor tmax_data_;
+  lite::Tensor tsum_data_;
+  size_t sharedmem_size_;
+  int max_dimsize_;
+  int axis_size_;
 };

 }  // namespace cuda
...
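The header change carries the actual optimization: tmax_data_ and tsum_data_ move from Run()-local temporaries to class members, so the device scratch memory behind them is allocated once and reused across calls instead of being re-acquired on every inference (the unused num_threads member is dropped along the way). A minimal sketch of the pattern, with Buffer and MyKernel as hypothetical stand-ins rather than PaddleLite types:

#include <cstddef>
#include <vector>

// Hypothetical stand-in for a device-backed tensor; not the lite::Tensor API.
struct Buffer {
  std::vector<float> storage;
  void Resize(size_t n) {
    // Grows only when needed; calling again with the same size is free,
    // which is what makes caching the buffer across runs pay off.
    if (n > storage.size()) storage.resize(n);
  }
};

class MyKernel {
 public:
  void Run(size_t n) {
    scratch_.Resize(n);  // reuses the previous allocation when n is unchanged
    // ... launch kernels that write into scratch_.storage ...
  }

 private:
  Buffer scratch_;  // persists between Run() calls, like tmax_data_/tsum_data_
};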