diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.h b/paddle/fluid/operators/softmax_cudnn_op.cu.h index 533488896dfcd177edfeaa5cd49f2cf36f7881a9..0c10152c23b2ae3809dca7503791775225905b4d 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.h +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.h @@ -222,15 +222,27 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, idx_max_v[i] = idx_max / kVSize; } - // read data from global memory + // data src AccT srcdata[kBatchSize][kLoopsV][kVSize]; - kps::Init(&srcdata[0][0][0], kLowInf); T src_tmp[kBatchSize][kLoopsV][kVSize]; + kps::Init(&srcdata[0][0][0], kLowInf); kps::Init(&src_tmp[0][0][0], -std::numeric_limits::infinity()); + + // data dst + T out_tmp[kBatchSize][kLoopsV][kVSize]; + + // max value + AccT max[kBatchSize]; + kps::Init(&max[0], kLowInf); + + // sum value + AccT sum[kBatchSize] = {0}; + +// read data from global memory #pragma unroll for (int i = 0; i < kBatchSize; ++i) { - int ptr = (first_batch + i) * stride; - const VecT* src_v = reinterpret_cast(&src[ptr]); + const VecT* src_v = + reinterpret_cast(&src[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&src_tmp[i][0][0]); kps::ReadData( ®_v[0], &src_v[0], idx_max_v[i], 0, kWarpSize, 1); @@ -239,15 +251,12 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, } // compute max - AccT max[kBatchSize]; - kps::Init(&max[0], kLowInf); kps::Reduce, kMode::kLocalMode>(&max[0], &srcdata[0][0][0], ReduceMaxFunctor(), true); WarpReduceMax(max); // compute sum - AccT sum[kBatchSize] = {0}; for (int i = 0; i < kBatchSize; ++i) { kps::ElementwiseUnary>( &srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor(max[i])); @@ -257,15 +266,14 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src, kps::AddFunctor(), true); WarpReduceSum(sum); - // write result to global memory - T out_tmp[kBatchSize][kLoopsV][kVSize]; +// write data to global memory #pragma unroll for (int i = 0; i < kBatchSize; ++i) { + VecT* softmax_v = + reinterpret_cast(&softmax[(first_batch + i) * stride]); + VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); kps::ElementwiseUnary>( &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor(sum[i])); - int softmax_ptr = (first_batch + i) * stride; - VecT* softmax_v = reinterpret_cast(&softmax[softmax_ptr]); - VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); kps::WriteData( &softmax_v[0], ®_v[0], idx_max_v[i], 0, kWarpSize, 1); }