From 2dec25dba21da5eb79bb364cb5eb58ece561a433 Mon Sep 17 00:00:00 2001 From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com> Date: Wed, 16 Mar 2022 22:10:13 +0800 Subject: [PATCH] Optimize the computation of log_softmax (#40612) * Optimize the computation of log_softmax * modify the var name --- paddle/phi/kernels/gpudnn/softmax_gpudnn.h | 43 ++++++++++------------ 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index 2b2dd51189..77159bfc87 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -121,17 +121,10 @@ struct ReduceMaxFunctor { }; template -struct ExpSubFunctor { - HOSTDEVICE inline ExpSubFunctor() { y = static_cast(0.0f); } - - HOSTDEVICE explicit inline ExpSubFunctor(Tx y) : y((Tx)(y)) {} - +struct ExpFunctor { HOSTDEVICE inline Ty operator()(const Tx& x) const { - return static_cast(std::exp(x - y)); + return static_cast(std::exp(x)); } - - private: - Tx y; }; template @@ -293,10 +286,14 @@ __global__ void WarpSoftmaxForward(T* softmax, } // data src - AccT srcdata[kBatchSize][kLoopsV][kVSize]; - T src_tmp[kBatchSize][kLoopsV][kVSize]; - kps::Init(&srcdata[0][0][0], kLowInf); - kps::Init(&src_tmp[0][0][0], -std::numeric_limits::infinity()); + // src_data: the raw data form global memory + // sub_data: store the data obtained by (src_data - max), used by log_softmax + // exp_data: store the data obtained by (exp(sub_data)), used by softmax + T src_data[kBatchSize][kLoopsV][kVSize]; + AccT sub_data[kBatchSize][kLoopsV][kVSize]; + AccT exp_data[kBatchSize][kLoopsV][kVSize]; + kps::Init(&sub_data[0][0][0], kLowInf); + kps::Init(&src_data[0][0][0], -std::numeric_limits::infinity()); // data dst T out_tmp[kBatchSize][kLoopsV][kVSize]; @@ -313,11 +310,11 @@ __global__ void WarpSoftmaxForward(T* softmax, for (int i = 0; i < kBatchSize; ++i) { const VecT* src_v = reinterpret_cast(&src[(first_batch + i) * stride]); - VecT* reg_v = reinterpret_cast(&src_tmp[i][0][0]); + VecT* reg_v = reinterpret_cast(&src_data[i][0][0]); kps::ReadData( ®_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); kps::ElementwiseUnary>( - &srcdata[i][0][0], &src_tmp[i][0][0], DataTransFunctor()); + &sub_data[i][0][0], &src_data[i][0][0], DataTransFunctor()); } // compute max @@ -327,14 +324,16 @@ __global__ void WarpSoftmaxForward(T* softmax, 1, ReduceMaxFunctor, kMode::kLocalMode>( - &max[0], &srcdata[0][0][0], ReduceMaxFunctor(), true); + &max[0], &sub_data[0][0][0], ReduceMaxFunctor(), true); WarpReduceMax(max); // compute sum #pragma unroll for (int i = 0; i < kBatchSize; ++i) { - kps::ElementwiseUnary>( - &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor(max[i])); + kps::ElementwiseUnary>( + &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor(max[i])); + kps::ElementwiseUnary>( + &exp_data[i][0][0], &sub_data[i][0][0], ExpFunctor()); } kps::Reduce, kMode::kLocalMode>( - &sum[0], &srcdata[0][0][0], kps::AddFunctor(), true); + &sum[0], &exp_data[0][0][0], kps::AddFunctor(), true); WarpReduceSum(sum); // write data to global memory @@ -352,15 +351,13 @@ __global__ void WarpSoftmaxForward(T* softmax, reinterpret_cast(&softmax[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); if (LogMode) { - kps::ElementwiseUnary>( - &srcdata[i][0][0], &srcdata[i][0][0], UnaryLogFunctor()); kps::ElementwiseUnary>( &out_tmp[i][0][0], - &srcdata[i][0][0], + &sub_data[i][0][0], UnarySubFunctor(std::log(sum[i]))); } else { kps::ElementwiseUnary>( - &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor(sum[i])); + &out_tmp[i][0][0], &exp_data[i][0][0], UnaryDivFunctor(sum[i])); } kps::WriteData( &softmax_v[0], ®_v[0], idx_max_v[i], 0, kWarpSize, 1); -- GitLab