From 18aca3f5691a5b7c4c4a4f213f4923f36ffdecfc Mon Sep 17 00:00:00 2001 From: Feng Xing <79969986+xingfeng01@users.noreply.github.com> Date: Thu, 9 Dec 2021 11:16:55 +0800 Subject: [PATCH] format softmax forward (#37927) --- paddle/fluid/operators/softmax_cudnn_op.cu.h | 32 ++++++++++++-------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.h b/paddle/fluid/operators/softmax_cudnn_op.cu.h index 533488896df..0c10152c23b 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.h +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.h @@ -222,15 +222,27 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, idx_max_v[i] = idx_max / kVSize; } - // read data from global memory + // data src AccT srcdata[kBatchSize][kLoopsV][kVSize]; - kps::Init(&srcdata[0][0][0], kLowInf); T src_tmp[kBatchSize][kLoopsV][kVSize]; + kps::Init(&srcdata[0][0][0], kLowInf); kps::Init(&src_tmp[0][0][0], -std::numeric_limits::infinity()); + + // data dst + T out_tmp[kBatchSize][kLoopsV][kVSize]; + + // max value + AccT max[kBatchSize]; + kps::Init(&max[0], kLowInf); + + // sum value + AccT sum[kBatchSize] = {0}; + +// read data from global memory #pragma unroll for (int i = 0; i < kBatchSize; ++i) { - int ptr = (first_batch + i) * stride; - const VecT* src_v = reinterpret_cast(&src[ptr]); + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&src_tmp[i][0][0]); kps::ReadData( ®_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); @@ -239,15 +251,12 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, } // compute max - AccT max[kBatchSize]; - kps::Init(&max[0], kLowInf); kps::Reduce, kMode::kLocalMode>(&max[0], &srcdata[0][0][0], ReduceMaxFunctor(), true); WarpReduceMax(max); // compute sum - AccT sum[kBatchSize] = {0}; for (int i = 0; i < kBatchSize; ++i) { kps::ElementwiseUnary>( &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor(max[i])); @@ -257,15 +266,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, kps::AddFunctor(), true); WarpReduceSum(sum); - // write result to global memory - T out_tmp[kBatchSize][kLoopsV][kVSize]; +// write data to global memory #pragma unroll for (int i = 0; i < kBatchSize; ++i) { + VecT* softmax_v = + reinterpret_cast(&softmax[(first_batch + i) * stride]); + VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); kps::ElementwiseUnary>( &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor(sum[i])); - int softmax_ptr = (first_batch + i) * stride; - VecT* softmax_v = reinterpret_cast(&softmax[softmax_ptr]); - VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); kps::WriteData( &softmax_v[0], ®_v[0], idx_max_v[i], 0, kWarpSize, 1); } -- GitLab