From 2ea15fc9618b82f8003ea6df57a6446fc117282a Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Fri, 11 Feb 2022 15:53:18 +0800
Subject: [PATCH] Optimize performance of softmax_bwd when axis!=-1 (#38609)

* Optimize performance of softmax_bwd when axis!=-1

* fix

* fix

* fix

* fix
---
 paddle/fluid/operators/softmax_cudnn_op.cu.h | 62 ++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.h b/paddle/fluid/operators/softmax_cudnn_op.cu.h
index 85bc3946bae..dc5166f4f99 100644
--- a/paddle/fluid/operators/softmax_cudnn_op.cu.h
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu.h
@@ -584,6 +584,43 @@ __global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
   }
 }
 
+template <typename T, typename AccT,
+          template <typename, typename> class Functor>
+__global__ void NormalSoftmaxBackward(T* input_grad, const T* output_grad,
+                                      const T* output, int high_dim,
+                                      int mid_dim, int low_dim) {
+  using kMode = kps::details::ReduceMode;
+  const int high_stride = mid_dim * low_dim;
+  const int mid_stride = low_dim;
+  for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
+    for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
+         low_id += blockDim.x * gridDim.x) {
+      const int grad_offset = high_id * high_stride + low_id;
+
+      // 1. reduce sum
+      AccT sum = 0;
+      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+        int data_offset = grad_offset + mid_id * mid_stride;
+        sum += static_cast<AccT>(output_grad[data_offset]) *
+               static_cast<AccT>(output[data_offset]);
+      }
+      if (blockDim.y > 1) {
+        kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
+            &sum, &sum, kps::AddFunctor<AccT>(), false);
+      }
+
+      // 2. (log)softmax backward
+      Functor<T, AccT> functor(sum);
+      for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+        int data_offset = grad_offset + mid_id * mid_stride;
+        input_grad[data_offset] =
+            functor(static_cast<AccT>(output_grad[data_offset]),
+                    static_cast<AccT>(output[data_offset]));
+      }
+    }
+  }
+}
+
 template <typename T, typename AccT>
 void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
                                 T* output_data, const T* input_data,
@@ -603,6 +640,28 @@ void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
   }
 }
 
+template <typename T, bool LogMode = false>
+void LaunchNormalSoftmaxBackward(const platform::CUDADeviceContext& dev_ctx,
+                                 T* input_grad_data, const T* output_grad_data,
+                                 const T* output_data, int high_dim,
+                                 int mid_dim, int low_dim) {
+  using AccT = typename details::MPTypeTrait<T>::Type;
+  dim3 grid, block;
+  GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
+  if (LogMode) {
+    NormalSoftmaxBackward<
+        T, AccT,
+        LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+        input_grad_data, output_grad_data, output_data, high_dim, mid_dim,
+        low_dim);
+  } else {
+    NormalSoftmaxBackward<
+        T, AccT, SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+        input_grad_data, output_grad_data, output_data, high_dim, mid_dim,
+        low_dim);
+  }
+}
+
 template <typename T>
 void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                     const Tensor& x, const int input_axis,
@@ -741,6 +800,9 @@ void SoftmaxBackwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
           blocks, threads, dev_ctx, dx_data, dout.data<T>(), out.data<T>(), N,
           dim, dim, kDimLog2);
     }
+  } else if (D > 1) {
+    LaunchNormalSoftmaxBackward<T, LogMode>(dev_ctx, dx_data, dout.data<T>(),
+                                            out.data<T>(), N, dim, D);
   } else {
     ScopedTensorDescriptor desc;
     std::vector<int> tensor_dims = {N, dim, D, 1};
-- 
GitLab