From 8776f751d56425771e8fa74e2a689d9ce180ab51 Mon Sep 17 00:00:00 2001 From: Wilber Date: Tue, 4 Aug 2020 11:04:54 +0800 Subject: [PATCH] [CUDA] Add cudnn_softmax. (#4040) --- lite/backends/cuda/math/CMakeLists.txt | 6 +- lite/backends/cuda/math/cudnn_conv.cc | 29 +- lite/backends/cuda/math/cudnn_helper.cc | 38 +++ lite/backends/cuda/math/cudnn_helper.h | 13 +- lite/backends/cuda/math/cudnn_softmax.cc | 105 +++++++ lite/backends/cuda/math/cudnn_softmax.h | 64 ++++ lite/kernels/cuda/CMakeLists.txt | 2 +- lite/kernels/cuda/softmax_compute.cu | 341 +++++++++++++++++----- lite/kernels/cuda/softmax_compute.h | 5 +- lite/kernels/cuda/softmax_compute_test.cc | 262 +++++++++++------ lite/operators/op_params.h | 1 + lite/operators/softmax_op.cc | 3 + 12 files changed, 683 insertions(+), 186 deletions(-) create mode 100644 lite/backends/cuda/math/cudnn_helper.cc create mode 100644 lite/backends/cuda/math/cudnn_softmax.cc create mode 100644 lite/backends/cuda/math/cudnn_softmax.h diff --git a/lite/backends/cuda/math/CMakeLists.txt b/lite/backends/cuda/math/CMakeLists.txt index c23d3d0ed0..14e5ae3840 100644 --- a/lite/backends/cuda/math/CMakeLists.txt +++ b/lite/backends/cuda/math/CMakeLists.txt @@ -8,7 +8,9 @@ nv_library(cuda_activation SRCS activation.cu DEPS ${cuda_static_deps}) nv_library(cuda_scale SRCS scale.cu DEPS ${cuda_static_deps}) nv_library(cuda_type_trans SRCS type_trans.cu DEPS ${cuda_static_deps}) nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps}) -nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans ${cuda_static_deps}) +nv_library(cudnn_helper SRCS cudnn_helper.cc DEPS ${cuda_static_deps}) +nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale cuda_type_trans cudnn_helper ${cuda_static_deps}) +nv_library(cudnn_softmax SRCS cudnn_softmax.cc DEPS cudnn_helper ${cuda_static_deps} tensor) nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps}) nv_library(cudnn_pool SRCS cudnn_pool.cc DEPS ${cuda_static_deps}) nv_library(cuda_gru_forward SRCS gru_forward.cu DEPS cuda_activation ${cuda_static_deps}) @@ -22,6 +24,7 @@ nv_library(cuda_bias SRCS bias.cu DEPS ${cuda_static_deps}) set ( math_cuda cudnn_conv + cudnn_softmax cuda_activation cuda_scale cuda_type_trans @@ -35,6 +38,7 @@ set ( cuda_strided_gemm cuda_sequence_padding cuda_bias + cudnn_helper ) set(math_cuda "${math_cuda}" CACHE GLOBAL "math cuda") diff --git a/lite/backends/cuda/math/cudnn_conv.cc b/lite/backends/cuda/math/cudnn_conv.cc index 5db41302c0..786ca33a18 100644 --- a/lite/backends/cuda/math/cudnn_conv.cc +++ b/lite/backends/cuda/math/cudnn_conv.cc @@ -15,6 +15,7 @@ #include "lite/backends/cuda/math/cudnn_conv.h" #include "lite/backends/cuda/math/activation.h" #include "lite/backends/cuda/math/conv_op_cache_cudnn.h" +#include "lite/backends/cuda/math/cudnn_helper.h" #include "lite/backends/cuda/math/scale.h" #include "lite/backends/cuda/math/type_trans.h" @@ -23,19 +24,6 @@ namespace lite { namespace cuda { namespace math { -template -cudnnDataType_t GetDataType(); - -template <> -cudnnDataType_t GetDataType() { - return CUDNN_DATA_FLOAT; -} - -template <> -cudnnDataType_t GetDataType() { - return CUDNN_DATA_HALF; -} - template bool CudnnConv2D::create(const operators::ConvParam& param, Context* ctx) { @@ -67,13 +55,13 @@ bool CudnnConv2D::create(const operators::ConvParam& param, CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->input_desc_, CUDNN_TENSOR_NCHW, - GetDataType(), + GetCudnnDataType(), batch, ic, ih, iw)); CUDNN_CHECK(cudnnSetFilter4dDescriptor(this->filter_desc_, - GetDataType(), + GetCudnnDataType(), CUDNN_TENSOR_NCHW, oc, ic / param.groups, @@ -87,11 +75,11 @@ bool CudnnConv2D::create(const operators::ConvParam& param, dh, dw, CUDNN_CROSS_CORRELATION, - GetDataType())); + GetCudnnDataType())); CUDNN_CHECK(cudnnSetConvolutionGroupCount(this->conv_desc_, param.groups)); CUDNN_CHECK(cudnnSetTensor4dDescriptor(this->output_desc_, CUDNN_TENSOR_NCHW, - GetDataType(), + GetCudnnDataType(), batch, oc, oh, @@ -190,8 +178,11 @@ bool CudnnConv2D::create(const operators::ConvParam& param, if (param.bias) { int dim_bias[] = {1, oc, 1, 1}; int stride_bias[] = {oc, 1, 1, 1}; - cudnnSetTensorNdDescriptor( - this->bias_desc_, GetDataType(), 4, dim_bias, stride_bias); + cudnnSetTensorNdDescriptor(this->bias_desc_, + GetCudnnDataType(), + 4, + dim_bias, + stride_bias); } return true; } diff --git a/lite/backends/cuda/math/cudnn_helper.cc b/lite/backends/cuda/math/cudnn_helper.cc new file mode 100644 index 0000000000..92cb320961 --- /dev/null +++ b/lite/backends/cuda/math/cudnn_helper.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/cuda/math/cudnn_helper.h" + +#include +#include + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template <> +cudnnDataType_t GetCudnnDataType() { + return CUDNN_DATA_FLOAT; +} + +template <> +cudnnDataType_t GetCudnnDataType() { + return CUDNN_DATA_HALF; +} + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/cudnn_helper.h b/lite/backends/cuda/math/cudnn_helper.h index b7f9b2cf69..972d841d97 100644 --- a/lite/backends/cuda/math/cudnn_helper.h +++ b/lite/backends/cuda/math/cudnn_helper.h @@ -13,12 +13,23 @@ // limitations under the License. #pragma once +#include + #include +#include + +#include "lite/api/paddle_place.h" +#include "lite/backends/cuda/cuda_utils.h" namespace paddle { namespace lite { namespace cuda { -namespace math {} // namespace math +namespace math { + +template +cudnnDataType_t GetCudnnDataType(); + +} // namespace math } // namespace cuda } // namespace lite } // namespace paddle diff --git a/lite/backends/cuda/math/cudnn_softmax.cc b/lite/backends/cuda/math/cudnn_softmax.cc new file mode 100644 index 0000000000..5aafc519ac --- /dev/null +++ b/lite/backends/cuda/math/cudnn_softmax.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/backends/cuda/math/cudnn_softmax.h" + +#include "lite/backends/cuda/math/cudnn_helper.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +bool CudnnSoftmax::Init(const operators::SoftmaxParam& param, + Context* ctx) { + this->stream_ = ctx->exec_stream(); + CUDNN_CHECK(cudnnCreate(&this->handle_)); + CUDNN_CHECK(cudnnSetStream(this->handle_, this->stream_)); + + cudnnCreateTensorDescriptor(&this->bottom_desc_); + cudnnCreateTensorDescriptor(&this->top_desc_); + + return Create(param, ctx); +} + +template +bool CudnnSoftmax::Create(const operators::SoftmaxParam& param, + Context* ctx) { + int axis = param.axis; + if (axis < 0) { + axis += param.x->dims().size(); + } + int outer_num = param.x->dims().count(0, axis); + int inner_num = param.x->dims().count(axis + 1, param.x->dims().size()); + + int N = outer_num; + int C = param.x->dims()[axis]; + int H = inner_num; + int W = 1; + + const int stride_w = 1; + const int stride_h = W * stride_w; + const int stride_c = H * stride_h; + const int stride_n = C * stride_c; + CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(bottom_desc_, + GetCudnnDataType(), + N, + C, + H, + W, + stride_n, + stride_c, + stride_h, + stride_w)); + CUDNN_CHECK(cudnnSetTensor4dDescriptorEx(top_desc_, + GetCudnnDataType(), + N, + C, + H, + W, + stride_n, + stride_c, + stride_h, + stride_w)); + handle_setup_ = true; + return true; +} + +template +bool CudnnSoftmax::Run(const operators::SoftmaxParam& param) { + T* output_data = param.output->mutable_data(TARGET(kCUDA)); + const T* input_data = param.x->data(); + float alpha = 1.0f; + float beta = 0.f; + CUDNN_CHECK(cudnnSoftmaxForward(handle_, + CUDNN_SOFTMAX_ACCURATE, + CUDNN_SOFTMAX_MODE_CHANNEL, + &alpha, + bottom_desc_, + reinterpret_cast(input_data), + &beta, + top_desc_, + reinterpret_cast(output_data))); + + return true; +} + +template class CudnnSoftmax; +template class CudnnSoftmax; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/backends/cuda/math/cudnn_softmax.h b/lite/backends/cuda/math/cudnn_softmax.h new file mode 100644 index 0000000000..87200bb695 --- /dev/null +++ b/lite/backends/cuda/math/cudnn_softmax.h @@ -0,0 +1,64 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include + +#include "lite/api/paddle_place.h" +#include "lite/backends/cuda/cuda_utils.h" +#include "lite/core/context.h" +#include "lite/core/target_wrapper.h" +#include "lite/operators/op_params.h" + +namespace paddle { +namespace lite { +namespace cuda { +namespace math { + +template +class CudnnSoftmax { + public: + CudnnSoftmax() + : handle_(nullptr), + bottom_desc_(nullptr), + top_desc_(nullptr), + handle_setup_(false) {} + virtual ~CudnnSoftmax() { + if (!handle_setup_) return; + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); + cudnnDestroy(handle_); + } + bool Init(const operators::SoftmaxParam& param, Context* ctx); + + bool Create(const operators::SoftmaxParam& param, + Context* ctx); + + bool Run(const operators::SoftmaxParam& param); + + private: + cudaStream_t stream_; + cudnnHandle_t handle_; + cudnnTensorDescriptor_t bottom_desc_; + cudnnTensorDescriptor_t top_desc_; + bool handle_setup_; +}; + +} // namespace math +} // namespace cuda +} // namespace lite +} // namespace paddle diff --git a/lite/kernels/cuda/CMakeLists.txt b/lite/kernels/cuda/CMakeLists.txt index 3d396cfa12..4a106f33b1 100644 --- a/lite/kernels/cuda/CMakeLists.txt +++ b/lite/kernels/cuda/CMakeLists.txt @@ -30,7 +30,7 @@ add_kernel(feed_compute_cuda CUDA basic SRCS feed_compute.cc DEPS ${lite_kernel_ add_kernel(fetch_compute_cuda CUDA basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps}) add_kernel(scale_compute_cuda CUDA basic SRCS scale_compute.cc DEPS ${lite_kernel_deps} cuda_scale) add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_kernel_deps} cuda_scale) -add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps}) +add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps} cudnn_pool) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) diff --git a/lite/kernels/cuda/softmax_compute.cu b/lite/kernels/cuda/softmax_compute.cu index 431bd6eb56..ad0ebe44da 100644 --- a/lite/kernels/cuda/softmax_compute.cu +++ b/lite/kernels/cuda/softmax_compute.cu @@ -12,6 +12,8 @@ limitations under the License. */ #pragma once #include #include + +#include "lite/backends/cuda/cuda_utils.h" #include "lite/core/op_registry.h" #include "lite/kernels/cuda/softmax_compute.h" @@ -21,8 +23,6 @@ namespace kernels { namespace cuda { using Tensor = lite::Tensor; -const int CUDA_NUM_THREADS = 512; - extern __shared__ char tile[]; template __global__ void sharemem_softmax_kernel(int total_size, @@ -73,6 +73,69 @@ __global__ void sharemem_softmax_kernel(int total_size, } } +template <> +__global__ void sharemem_softmax_kernel(int total_size, + const half* in_data, + half* out_data, + int inner_num, + int outer_num, + int axis_size) { + half* data = reinterpret_cast(tile) + threadIdx.x; + //! compute thread index and real data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int blocksize = blockDim.x; + int real_index = idx_outer * inner_num + idx_inner; + int loop_idx = real_index; +//! read all data to sharemem in softmax channel +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + data[i * blocksize] = in_data[loop_idx]; + loop_idx += inner_num; + } + //! get maximum value in softmax channel + half max_data = data[0]; +#pragma unroll + for (int i = 1; i < axis_size; ++i) { + half dt = data[i * blocksize]; +#if __CUDA_ARCH__ >= 530 + if (__hlt(max_data, dt)) { +#else + if (__half2float(max_data) < __half2float(dt)) { +#endif + max_data = dt; + } + } + //! subtract then summarize + half sum = 0; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { + half* dt = data + i * blocksize; +#if __CUDA_ARCH__ >= 530 + *dt = hexp(__hsub(*dt, max_data)); + sum = __hadd(sum, *dt); +#else + *dt = __float2half(expf(__half2float(*dt) - __half2float(max_data))); + sum = __float2half(__half2float(sum) + __half2float(*dt)); +#endif + } + //! write back result + loop_idx = real_index; +#pragma unroll + for (int i = 0; i < axis_size; ++i) { +#if __CUDA_ARCH__ >= 530 + out_data[loop_idx] = __hdiv(data[i * blocksize], sum); +#else + out_data[loop_idx] = + __float2half(__half2float(data[i * blocksize]) / __half2float(sum)); +#endif + loop_idx += inner_num; + } + } +} + //! general kernel for softmax template __global__ void softmax_max_kernel(int total_size, @@ -99,6 +162,38 @@ __global__ void softmax_max_kernel(int total_size, } } +template <> +__global__ void softmax_max_kernel(int total_size, + const half* in_data, + half* out_data, + half min_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + int real_index = idx_outer * inner_num + idx_inner; + //! get maximum data across softmax axis + half max_data = min_data; + for (int i = 0; i < axis_size; ++i) { +#if __CUDA_ARCH__ >= 530 + max_data = + __hgt(in_data[real_index], max_data) ? in_data[real_index] : max_data; +#else + float a = __half2float(in_data[real_index]); + float b = __half2float(max_data); + float res = a > b ? a : b; + max_data = __float2half(res); +#endif + real_index += inner_num; + } + out_data[idx] = max_data; + } +} + template __global__ void softmax_sub_exp_sum_kernel(int total_size, const dtype* in_data, @@ -129,6 +224,44 @@ __global__ void softmax_sub_exp_sum_kernel(int total_size, } } +template <> +__global__ void softmax_sub_exp_sum_kernel(int total_size, + const half* in_data, + half* out_data, + const half* max_data, + half* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; + + half max_data_cur = max_data[idx]; + half sum_data_cur = 0; + int real_index = idx_outer * inner_num + idx_inner; + //! compute exp and summarize across the softmax axis + for (int i = 0; i < axis_size; ++i) { +#if __CUDA_ARCH__ >= 530 + half sub_data = __hsub(in_data[real_index], max_data_cur); + sub_data = hexp(sub_data); + sum_data_cur = __hadd(sum_data_cur, sub_data); +#else + half sub_data = __float2half(__half2float(in_data[real_index]) - + __half2float(max_data_cur)); + sub_data = __float2half(expf(__half2float(sub_data))); + sum_data_cur = + __float2half(__half2float(sum_data_cur) + __half2float(sub_data)); +#endif + out_data[real_index] = sub_data; + real_index += inner_num; + } + sum_data[idx] = sum_data_cur; + } +} + template __global__ void softmax_divid_output_kernel(int total_size, dtype* io_data, @@ -151,75 +284,122 @@ __global__ void softmax_divid_output_kernel(int total_size, } } -void SoftmaxCompute::PrepareForRun() { +template <> +__global__ void softmax_divid_output_kernel(int total_size, + half* io_data, + const half* sum_data, + int inner_num, + int outer_num, + int axis_size) { + //! compute data index + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < total_size) { + int idx_inner = idx % inner_num; + int idx_outer = (idx / inner_num) * axis_size; +#if __CUDA_ARCH__ >= 530 + half sum_data_cur = __hdiv(__float2half(1.f), sum_data[idx]); +#else + half sum_data_cur = __float2half(1.f / __half2float(sum_data[idx])); +#endif + int real_index = idx_outer * inner_num + idx_inner; + //! compute final result + for (int i = 0; i < axis_size; ++i) { +#if __CUDA_ARCH__ >= 530 + io_data[real_index] = __hmul(io_data[real_index], sum_data_cur); +#else + io_data[real_index] = __float2half(__half2float(io_data[real_index]) * + __half2float(sum_data_cur)); +#endif + real_index += inner_num; + } + } +} + +template +void SoftmaxCompute::PrepareForRun() { + auto& param = this->template Param(); + auto& ctx = this->ctx_->template As(); int device_id; cudaGetDevice(&device_id); cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, device_id); sharedmem_size_ = deviceProp.sharedMemPerBlock; max_dimsize_ = sharedmem_size_ / sizeof(float) / CUDA_NUM_THREADS; + if (param.use_cudnn) { + cudnn_softmax_.Init(param, &ctx); + } } -void SoftmaxCompute::Run() { - auto& param = this->Param(); +template +void SoftmaxCompute::Run() { + auto& param = this->template Param(); auto& ctx = this->ctx_->template As(); auto stream = ctx.exec_stream(); - - auto x_dims = param.x->dims(); - auto x_rank = x_dims.size(); - int axis = param.axis; - if (axis < 0) { - axis += x_rank; - } - int outer_num = x_dims.Slice(0, axis).production(); - int inner_num = x_dims.Slice(axis + 1, x_rank).production(); - int total_threads = inner_num * outer_num; - axis_size_ = x_dims[axis]; - - const int threads = CUDA_NUM_THREADS; - const int blocks = (total_threads + threads - 1) / threads; - auto input_data = param.x->data(); - auto output_data = param.output->mutable_data(TARGET(kCUDA)); - if (axis_size_ <= max_dimsize_) { - int use_sharemem_size = axis_size_ * threads * sizeof(float); - sharemem_softmax_kernel<<>>( - total_threads, - input_data, - output_data, - inner_num, - outer_num, - axis_size_); + if (param.use_cudnn) { + cudnn_softmax_.Create(param, &ctx); + cudnn_softmax_.Run(param); } else { - //! re_alloc device memory - tmax_data_.Resize({1, 1, 1, outer_num * inner_num}); - tsum_data_.Resize({1, 1, 1, outer_num * inner_num}); - auto max_data = tmax_data_.mutable_data(TARGET(kCUDA)); - auto sum_data = tsum_data_.mutable_data(TARGET(kCUDA)); - //! firstly, get maximum data - float min_data = std::numeric_limits::lowest(); - softmax_max_kernel<<>>(total_threads, - input_data, - max_data, - min_data, - inner_num, - outer_num, - axis_size_); - //! then, compute exp and sum data - softmax_sub_exp_sum_kernel<<>>( - total_threads, - input_data, - output_data, - max_data, - sum_data, - inner_num, - outer_num, - axis_size_); - //! last, compute divided output - softmax_divid_output_kernel<<>>( - total_threads, output_data, sum_data, inner_num, outer_num, axis_size_); + auto x_dims = param.x->dims(); + auto x_rank = x_dims.size(); + int axis = param.axis; + if (axis < 0) { + axis += x_rank; + } + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int total_threads = inner_num * outer_num; + axis_size_ = x_dims[axis]; + + const int threads = CUDA_NUM_THREADS; + const int blocks = (total_threads + threads - 1) / threads; + auto input_data = param.x->template data(); + auto output_data = + param.output->template mutable_data(TARGET(kCUDA)); + if (axis_size_ <= max_dimsize_) { + int use_sharemem_size = axis_size_ * threads * sizeof(Dtype); + sharemem_softmax_kernel< + Dtype><<>>(total_threads, + input_data, + output_data, + inner_num, + outer_num, + axis_size_); + } else { + //! re_alloc device memory + tmax_data_.Resize({1, 1, 1, outer_num * inner_num}); + tsum_data_.Resize({1, 1, 1, outer_num * inner_num}); + auto max_data = tmax_data_.mutable_data(TARGET(kCUDA)); + auto sum_data = tsum_data_.mutable_data(TARGET(kCUDA)); + //! firstly, get maximum data + float min_data = std::numeric_limits::lowest(); + softmax_max_kernel<<>>(total_threads, + input_data, + max_data, + min_data, + inner_num, + outer_num, + axis_size_); + //! then, compute exp and sum data + softmax_sub_exp_sum_kernel<<>>( + total_threads, + input_data, + output_data, + max_data, + sum_data, + inner_num, + outer_num, + axis_size_); + //! last, compute divided output + softmax_divid_output_kernel<<>>( + total_threads, + output_data, + sum_data, + inner_num, + outer_num, + axis_size_); + } } - cudaError_t error = cudaGetLastError(); - if (error != cudaSuccess) LOG(ERROR) << cudaGetErrorString(error); + CUDA_POST_KERNEL_CHECK; } } // namespace cuda @@ -227,12 +407,12 @@ void SoftmaxCompute::Run() { } // namespace lite } // namespace paddle -REGISTER_LITE_KERNEL(softmax, - kCUDA, - kFloat, - kNCHW, - paddle::lite::kernels::cuda::SoftmaxCompute, - def) +using SoftmaxFp32 = + paddle::lite::kernels::cuda::SoftmaxCompute; +using SoftmaxFp16 = + paddle::lite::kernels::cuda::SoftmaxCompute; + +REGISTER_LITE_KERNEL(softmax, kCUDA, kFloat, kNCHW, SoftmaxFp32, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), @@ -246,12 +426,19 @@ REGISTER_LITE_KERNEL(softmax, PRECISION(kFloat), DATALAYOUT(kNCHW))}) .Finalize(); -REGISTER_LITE_KERNEL(search_seq_softmax, - kCUDA, - kFloat, - kNCHW, - paddle::lite::kernels::cuda::SoftmaxCompute, - def) +REGISTER_LITE_KERNEL(softmax, kCUDA, kFP16, kNCHW, SoftmaxFp16, def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFP16), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFP16), + DATALAYOUT(kNCHW))}) + .BindOutput("Out_log", + {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))}) + .Finalize(); +REGISTER_LITE_KERNEL(search_seq_softmax, kCUDA, kFloat, kNCHW, SoftmaxFp32, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat), @@ -262,3 +449,15 @@ REGISTER_LITE_KERNEL(search_seq_softmax, DATALAYOUT(kNCHW))}) .BindOutput("Out_log", {LiteType::GetTensorTy(TARGET(kCUDA))}) .Finalize(); +REGISTER_LITE_KERNEL(search_seq_softmax, kCUDA, kFP16, kNCHW, SoftmaxFp16, def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFP16), + DATALAYOUT(kNCHW))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kCUDA), + PRECISION(kFP16), + DATALAYOUT(kNCHW))}) + .BindOutput("Out_log", + {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFP16))}) + .Finalize(); diff --git a/lite/kernels/cuda/softmax_compute.h b/lite/kernels/cuda/softmax_compute.h index e563b36178..57fe2890f6 100644 --- a/lite/kernels/cuda/softmax_compute.h +++ b/lite/kernels/cuda/softmax_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include "lite/backends/cuda/math/cudnn_softmax.h" #include "lite/core/kernel.h" namespace paddle { @@ -20,8 +21,9 @@ namespace lite { namespace kernels { namespace cuda { +template class SoftmaxCompute - : public KernelLite { + : public KernelLite { public: using param_t = operators::SoftmaxParam; @@ -30,6 +32,7 @@ class SoftmaxCompute virtual ~SoftmaxCompute() = default; private: + lite::cuda::math::CudnnSoftmax cudnn_softmax_; lite::Tensor tmax_data_; lite::Tensor tsum_data_; size_t sharedmem_size_; diff --git a/lite/kernels/cuda/softmax_compute_test.cc b/lite/kernels/cuda/softmax_compute_test.cc index b4d5352091..404d8e8513 100644 --- a/lite/kernels/cuda/softmax_compute_test.cc +++ b/lite/kernels/cuda/softmax_compute_test.cc @@ -20,114 +20,192 @@ #include #include +#include "lite/api/test_helper.h" +#include "lite/utils/float16.h" + namespace paddle { namespace lite { namespace kernels { namespace cuda { -using Tensor = lite::Tensor; -using DDim = lite::DDim; - -template -static void softmax_compute_ref(const operators::SoftmaxParam& param) { - const dtype* x_data = param.x->mutable_data(); - dtype* output_data = param.output->mutable_data(); - DDim x_dims = param.x->dims(); - ASSERT_EQ(x_dims.data(), param.output->dims().data()); - auto x_rank = x_dims.size(); - int axis = param.axis; - if (axis < 0) { - axis += x_rank; +class SoftmaxTest : public ::testing::Test { + protected: + SoftmaxTest() + : n_(2), + c_(2), + h_(2), + w_(2), + axis_(1), + use_cudnn_(true), + shape_({n_, c_, h_, w_}) { + x_ref_.Resize(lite::DDim(shape_)); + x_gpu_.Resize(lite::DDim(shape_)); + + auto x_ref_data = x_ref_.mutable_data(); + + for (int64_t i = 0; i < x_ref_.numel(); i++) { + x_ref_data[i] = static_cast(i % 10 * 0.2); + } + + out_ref_.Resize(lite::DDim(shape_)); + out_cpu_.Resize(out_ref_.dims()); + out_gpu_.Resize(out_ref_.dims()); + RunBaseLine(); + + InitParamAndContext(); + } + + void InitParamAndContext() { + ctx_.reset(new KernelContext); + cudaStreamCreate(&stream_); + auto& context = ctx_->As(); + context.SetExecStream(stream_); + param_.x = &x_gpu_; + param_.axis = axis_; + param_.output = &out_gpu_; + param_.use_cudnn = use_cudnn_; + } + + void InitFloatInput() { + x_gpu_.Assign(x_ref_.data(), + x_gpu_.dims()); } - int axis_size = x_dims[axis]; - int outer_num = x_dims.Slice(0, axis).production(); - int inner_num = x_dims.Slice(axis + 1, x_rank).production(); - int compute_size = outer_num * inner_num; - for (int i = 0; i < compute_size; i++) { - int idx_inner = i % inner_num; - int idx_outer = (i / inner_num) * axis_size; - int start = idx_outer * inner_num + idx_inner; - int offset; - - offset = start; - dtype max_data = std::numeric_limits::lowest(); - for (int j = 0; j < axis_size; j++) { - max_data = x_data[offset] > max_data ? x_data[offset] : max_data; - offset += inner_num; + + void InitHalfInput() { + x_half_.Resize(x_ref_.dims()); + auto x_half_data = x_half_.mutable_data(); + for (int64_t i = 0; i < x_half_.numel(); i++) { + x_half_data[i] = half(lite::float16(x_ref_.data()[i])); } + x_gpu_.Assign(x_half_data, x_gpu_.dims()); + } - offset = start; - dtype sum_data = (dtype)0; - for (int j = 0; j < axis_size; j++) { - output_data[offset] = exp(x_data[offset] - max_data); - sum_data += output_data[offset]; - offset += inner_num; + void RunBaseLine() { + const float* x_data = x_ref_.mutable_data(); + float* output_data = out_ref_.mutable_data(); + DDim x_dims = x_ref_.dims(); + ASSERT_EQ(x_dims.data(), out_ref_.dims().data()); + auto x_rank = x_dims.size(); + int axis = axis_; + if (axis < 0) { + axis += x_rank; } + int axis_size = x_dims[axis]; + int outer_num = x_dims.Slice(0, axis).production(); + int inner_num = x_dims.Slice(axis + 1, x_rank).production(); + int compute_size = outer_num * inner_num; + for (int i = 0; i < compute_size; i++) { + int idx_inner = i % inner_num; + int idx_outer = (i / inner_num) * axis_size; + int start = idx_outer * inner_num + idx_inner; + int offset; + + offset = start; + float max_data = std::numeric_limits::lowest(); + for (int j = 0; j < axis_size; j++) { + max_data = x_data[offset] > max_data ? x_data[offset] : max_data; + offset += inner_num; + } - offset = start; - for (int j = 0; j < axis_size; j++) { - output_data[offset] /= sum_data; - offset += inner_num; + offset = start; + float sum_data = 0.f; + for (int j = 0; j < axis_size; j++) { + output_data[offset] = exp(x_data[offset] - max_data); + sum_data += output_data[offset]; + offset += inner_num; + } + + offset = start; + for (int j = 0; j < axis_size; j++) { + output_data[offset] /= sum_data; + offset += inner_num; + } } } + + int n_, c_, h_, w_, axis_; + bool use_cudnn_; + std::vector shape_; + lite::Tensor x_ref_, out_ref_; + lite::Tensor x_gpu_; + lite::Tensor x_half_; + lite::Tensor out_cpu_, out_gpu_; + + operators::SoftmaxParam param_; + std::unique_ptr ctx_; + cudaStream_t stream_; +}; + +TEST_F(SoftmaxTest, TestFP32) { + InitFloatInput(); + SoftmaxCompute kernel; + kernel.SetParam(param_); + kernel.SetContext(std::move(ctx_)); + + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel.Launch(); + cudaDeviceSynchronize(); + } + + auto start = GetCurrentUS(); + kernel.PrepareForRun(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel.Run(); + } + cudaDeviceSynchronize(); + auto duration = (GetCurrentUS() - start) / 1000.0; + LOG(INFO) << "fp32, warmup: " << FLAGS_warmup + << ", repeats: " << FLAGS_repeats << ", spend " + << duration / FLAGS_repeats << " ms in average."; + + CopySync(out_cpu_.mutable_data(), + out_gpu_.data(), + sizeof(float) * out_gpu_.numel(), + IoDirection::DtoH); + + for (int i = 0; i < out_gpu_.numel(); ++i) { + float res = out_cpu_.data()[i]; + float ref = out_ref_.data()[i]; + EXPECT_NEAR(fabs(res - ref) / ref, 0.f, 1e-5); + } } -TEST(softmax_cuda, compute) { - std::unique_ptr ctx(new KernelContext); - auto& context = ctx->As(); - cudaStream_t stream; - cudaStreamCreate(&stream); - context.SetExecStream(stream); - - SoftmaxCompute softmax; - operators::SoftmaxParam param; - softmax.SetContext(std::move(ctx)); - lite::Tensor x; - lite::Tensor x_cpu; - lite::Tensor output; - lite::Tensor output_cpu; - lite::Tensor output_ref; - for (auto n : {1, 3}) { - for (auto c : {1, 4}) { - for (auto h : {5, 1, 112}) { - for (auto w : {1, 6, 112}) { - for (auto axis : {-2, -1, 0, 1, 2}) { - x.Resize({n, c, h, w}); - x_cpu.Resize({n, c, h, w}); - output.Resize({n, c, h, w}); - output_cpu.Resize({n, c, h, w}); - output_ref.Resize({n, c, h, w}); - auto* x_cpu_data = x_cpu.mutable_data(); - auto* output_data = output.mutable_data(TARGET(kCUDA)); - auto* output_cpu_data = output_ref.mutable_data(); - auto* output_ref_data = output_ref.mutable_data(); - for (int i = 0; i < x.dims().production(); i++) { - x_cpu_data[i] = i; - } - x.Assign(x_cpu_data, - x_cpu.dims()); - param.x = &x; - param.axis = axis; - param.output = &output; - softmax.SetParam(param); - softmax.Launch(); - param.x = &x_cpu; - param.output = &output_ref; - softmax_compute_ref(param); - cudaDeviceSynchronize(); - CopySync(output_cpu_data, - output_data, - sizeof(float) * output.numel(), - IoDirection::DtoH); - for (int i = 0; i < output.dims().production(); i++) { - EXPECT_NEAR(output_cpu_data[i], output_ref_data[i], 1e-5); - } - } - } - } - } +TEST_F(SoftmaxTest, TestFP16) { + InitHalfInput(); + SoftmaxCompute kernel; + kernel.SetParam(param_); + kernel.SetContext(std::move(ctx_)); + + for (int i = 0; i < FLAGS_warmup; ++i) { + kernel.Launch(); + cudaDeviceSynchronize(); + } + + auto start = GetCurrentUS(); + kernel.PrepareForRun(); + for (int i = 0; i < FLAGS_repeats; ++i) { + kernel.Run(); + } + cudaDeviceSynchronize(); + auto duration = (GetCurrentUS() - start) / 1000.0; + LOG(INFO) << "fp16, warmup: " << FLAGS_warmup + << ", repeats: " << FLAGS_repeats << ", spend " + << duration / FLAGS_repeats << " ms in average."; + + const half* out_gpu_data = out_gpu_.data(); + half* out_cpu_data = out_cpu_.mutable_data(); + CopySync(out_cpu_data, + out_gpu_data, + sizeof(half) * out_gpu_.numel(), + IoDirection::DtoH); + + for (int i = 0; i < out_gpu_.numel(); ++i) { + float res = static_cast(lite::float16(out_cpu_data[i])); + float ref = out_ref_.data()[i]; + EXPECT_NEAR(fabs(res - ref) / (ref + 1e-5), 0., 1e-2); } } + } // namespace cuda } // namespace kernels } // namespace lite diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index 11a8943832..0e14e82b80 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -268,6 +268,7 @@ struct SoftmaxParam : ParamBase { lite::Tensor* x{}; lite::Tensor* output{}; int axis{-1}; + bool use_cudnn{true}; /////////////////////////////////////////////////////////////////////////////////// // get a vector of input tensors const std::vector* input_tensor_ptrs() override { diff --git a/lite/operators/softmax_op.cc b/lite/operators/softmax_op.cc index e95e355bda..4e4aa9c981 100644 --- a/lite/operators/softmax_op.cc +++ b/lite/operators/softmax_op.cc @@ -52,6 +52,9 @@ bool SoftmaxOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { } CHECK(param_.x); CHECK(param_.output); + if (opdesc.HasAttr("use_cudnn")) { + param_.use_cudnn = opdesc.GetAttr("use_cudnn"); + } return true; } -- GitLab