From 250e254fac69680e07510a3c4f34337656a56a79 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Mon, 14 Mar 2022 16:48:30 +0800
Subject: [PATCH] Optimize performance of log_softmax (#38992)

* Optimize performance of log_softmax
* delete unity build
* modify to phi
* fix
* fixfixfixfix
* fix
* fix
* fix
* fix
* simplify
* fix
* fix enforce
---
 paddle/fluid/operators/log_softmax_op.cu   | 464 ++-------------------
 paddle/phi/kernels/gpudnn/softmax_gpudnn.h | 122 ++++--
 2 files changed, 124 insertions(+), 462 deletions(-)

diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu
index 8770abdac83..26b6ce43303 100644
--- a/paddle/fluid/operators/log_softmax_op.cu
+++ b/paddle/fluid/operators/log_softmax_op.cu
@@ -12,459 +12,43 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <limits>
 #include "paddle/fluid/operators/log_softmax_op.h"
-#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
-#include "paddle/phi/common/amp_type_traits.h"
-#include "paddle/phi/kernels/funcs/elementwise_functor.h"
-#include "paddle/phi/kernels/funcs/functors.h"
+#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"
 
 namespace paddle {
 namespace operators {
 
-#define LAUNCH_WARP_FORWAR_COMPUTE(near_greater_power_of_two)                \
-  case near_greater_power_of_two:                                           \
-    ComputeLogSoftmaxForwardInWarp<                                          \
-        T, AccT, near_greater_power_of_two><<<blocks, threads, 0, stream>>>( \
-        dst, src, outer_size, dim_size);                                    \
-    break;
-
-template <typename T, int KernelWarpSize>
-__device__ __forceinline__ T WarpReduceSum(T value) {
-#pragma unroll
-  for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) {
-    T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset);
-    value = value + sum_val;
-  }
-  return value;
-}
-
-template <typename T, int KernelWarpSize>
-__device__ __forceinline__ T WarpReduceMax(T value) {
-#pragma unroll
-  for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) {
-    T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset);
-    value = max(value, max_val);
-  }
-  return value;
-}
-
-int GetNearGreaterPowerOfTwo(int value) {
-  int log2_value = 0;
-  while ((1 << log2_value) < value) {
-    ++log2_value;
-  }
-  return 1 << log2_value;
-}
-
-template <typename T, typename AccT, int NearGreaterPowerOfTwo>
-__global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src,
-                                               int batch_size,
-                                               int element_count) {
-  constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo;
-  constexpr int kernel_warp_size =
-      (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
-  constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
-  int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
-
-  int thread_in_warp_idx = threadIdx.x;
-
-  // 1. Read data from global memory into registers.
-  AccT elements[warp_iter];
-  // effective_element_count is the number of elements when a warp does
-  // effective work, and 0 when its batch_id is out of range.
-  int effective_element_count = (batch_id < batch_size) ? element_count : 0;
-  for (int it = 0; it < warp_iter; ++it) {
-    int element_index = thread_in_warp_idx + it * kernel_warp_size;
-    if (element_index < effective_element_count) {
-      elements[it] =
-          static_cast<AccT>(src[batch_id * element_count + element_index]);
-    } else {
-      elements[it] = -std::numeric_limits<AccT>::infinity();
-    }
-  }
-
-  // 2. Compute max_value: each thread loops over its registers to find the
-  // max.
-  AccT max_value = elements[0];
-#pragma unroll
-  for (int it = 1; it < warp_iter; ++it) {
-    max_value = (max_value > elements[it]) ? max_value : elements[it];
-  }
-  max_value = WarpReduceMax<AccT, kernel_warp_size>(max_value);
-
-  // 3. For each warp, accumulate all thread registers.
-  AccT sum = 0.0f;
-#pragma unroll
-  for (int it = 0; it < warp_iter; ++it) {
-    sum += std::exp(elements[it] - max_value);
-  }
-  sum = WarpReduceSum<AccT, kernel_warp_size>(sum);
-
-  // 4. Store the result.
-  sum = std::log(sum);
-#pragma unroll
-  for (int it = 0; it < warp_iter; ++it) {
-    int element_index = thread_in_warp_idx + it * kernel_warp_size;
-    if (element_index < effective_element_count) {
-      dst[batch_id * element_count + element_index] =
-          static_cast<T>(elements[it] - max_value - sum);
-    } else {
-      break;
-    }
-  }
-}
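The deleted warp kernel above is a fused form of the numerically stable identity log_softmax(x_i) = (x_i - max) - log(sum_j exp(x_j - max)). A minimal scalar sketch of the same four steps in plain C++ (illustrative only; the name LogSoftmaxRef is ours, not Paddle's):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Reference log_softmax over one row, mirroring steps 1-4 of the kernel:
    // read, max-reduce, sum-of-exp reduce, then write x - max - log(sum).
    std::vector<float> LogSoftmaxRef(const std::vector<float> &x) {
      float max_value = x[0];
      for (float v : x) max_value = std::max(max_value, v);  // step 2
      float sum = 0.f;
      for (float v : x) sum += std::exp(v - max_value);      // step 3
      const float log_sum = std::log(sum);
      std::vector<float> out(x.size());
      for (size_t i = 0; i < x.size(); ++i) {
        out[i] = x[i] - max_value - log_sum;                 // step 4
      }
      return out;
    }

Subtracting the max before exponentiating keeps every exp argument at or below zero, so the sum can never overflow.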
-
-template <typename T, typename AccT>
-void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size,
-                                     int outer_size, gpuStream_t stream) {
-  int threads_per_block = 128;
-  int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size);
-  int kernel_warp_size =
-      (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
-  int warps_per_block = (threads_per_block / kernel_warp_size);
-  int blocks = (outer_size + warps_per_block - 1) / warps_per_block;
-  dim3 threads(kernel_warp_size, warps_per_block, 1);
-
-  switch (near_greater_power_of_two) {
-    LAUNCH_WARP_FORWAR_COMPUTE(1);
-    LAUNCH_WARP_FORWAR_COMPUTE(2);
-    LAUNCH_WARP_FORWAR_COMPUTE(4);     // dim_size: 3~4
-    LAUNCH_WARP_FORWAR_COMPUTE(8);     // dim_size: 5~8
-    LAUNCH_WARP_FORWAR_COMPUTE(16);    // dim_size: 9~16
-    LAUNCH_WARP_FORWAR_COMPUTE(32);    // dim_size: 17~32
-    LAUNCH_WARP_FORWAR_COMPUTE(64);    // dim_size: 33~64
-    LAUNCH_WARP_FORWAR_COMPUTE(128);   // dim_size: 65~128
-    LAUNCH_WARP_FORWAR_COMPUTE(256);   // dim_size: 129~256
-    LAUNCH_WARP_FORWAR_COMPUTE(512);   // dim_size: 257~512
-    LAUNCH_WARP_FORWAR_COMPUTE(1024);  // dim_size: 513~1024
-
-    default:
-      break;
-  }
-}
-
-// Returns the final item after the reduce operation along block.x.
-// Firstly, get the shared memory (smem) offset, i.e. the starting position
-// for every y.
-// Secondly, initialise every smem position with the value 'val' of the
-// thread itself.
-// Thirdly, apply standard reduction along the x direction as below:
-//
-//   -> x direction
-// [o o o o o o o o]  time 0
-//  |   |/       /
-//  |   /|      /
-//  |  / |     /
-//  | /  |/
-// [o o o o x x x x]  time 1
-//  | |/ /
-//  |/|/
-// [o o x x x x x x]  time 2
-//  |/
-// [o x x x x x x x]  time 3
-//
-// Finally, return the first item.
-// Imagine multiple reductions executed in parallel along the y axis.
-// Note that when blockDim.x is not 1, it is an EVEN number in all cases,
-// and the size of shared memory is even as well.
-template <typename T, template <typename> class Functor>
-__forceinline__ __device__ T BlockReduceAlongDimX(T *shared, T val) {
-  Functor<T> func;
-  // This reduction is not a block-wise reduction; it only reduces along
-  // block.x, therefore the shared mem has offsets for different block.y.
-  shared += threadIdx.y * blockDim.x;
-  shared[threadIdx.x] = val;
-  int offset = blockDim.x / 2;
-
-  while (offset > 0) {
-    __syncthreads();
-    if (threadIdx.x < offset) {
-      shared[threadIdx.x] =
-          func(shared[threadIdx.x], shared[threadIdx.x + offset]);
-    }
-    offset /= 2;
-  }
-  __syncthreads();
-  return shared[0];
-}
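For intuition, the x-direction reduction above halves the active width each iteration, exactly as the diagram shows. A sequential host-side model of one row (our own illustrative code, not part of the patch):

    // Sequential model of the diagrammed tree reduction: 8 slots collapse
    // into slot 0 in log2(8) = 3 steps; each step pairs shared[i] with
    // shared[i + offset], like the device loop minus the __syncthreads().
    float TreeReduceSum(float *shared, int width /* power of two */) {
      for (int offset = width / 2; offset > 0; offset /= 2) {
        for (int i = 0; i < offset; ++i) {
          shared[i] = shared[i] + shared[i + offset];
        }
      }
      return shared[0];
    }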
-
-template <typename T, typename AccT>
-__global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
-    T *output, const T *input, int outer_size, int dim_size, int inner_size) {
-  extern __shared__ unsigned char smem[];
-  auto sdata = reinterpret_cast<AccT *>(smem);
-
-  const int outer_stride = inner_size * dim_size;
-  const int dim_stride = inner_size;
-
-  for (int x_id = blockIdx.x; x_id < outer_size; x_id += gridDim.x) {
-    for (int y_id = blockIdx.y * blockDim.y + threadIdx.y; y_id < inner_size;
-         y_id += blockDim.y * gridDim.y) {
-      const int data_offset = x_id * outer_stride + y_id;
-      // When blockDim.x == 1, no block.x-reduction operations are needed,
-      // and threadIdx.x is 0 all the time, so the for-loops below are plain
-      // serial loops (no parallel execution). Loop over all elements along
-      // the axis and calculate the Max, Sum and (input[id] - Max - log(Sum))
-      // to get the final log_softmax values along that axis.
-      // 1. reduce max
-      AccT max_value = -std::numeric_limits<AccT>::infinity();
-      // Each thread iterates over all items it is responsible for and gets
-      // its max_value. If there are N threads, N max_values are produced.
-      for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
-        const AccT value =
-            static_cast<AccT>(input[data_offset + d * dim_stride]);
-        max_value = phi::funcs::MaxFunctor<AccT>()(max_value, value);
-      }
-      // If there is more than one thread along block.x, reduce all the
-      // max_values to get the global max_value, i.e. the max along "axis".
-      // With a single thread along block.x no reduction is needed, as
-      // 'max_value' is already the global max_value.
-      if (blockDim.x > 1) {
-        max_value = BlockReduceAlongDimX<AccT, phi::funcs::MaxFunctor>(
-            sdata, max_value);
-      }
-
-      // 2. reduce sum
-      AccT sum = 0;
-      // Same execution pattern as '1. reduce max'.
-      for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
-        sum += std::exp(static_cast<AccT>(input[data_offset + d * dim_stride]) -
-                        max_value);
-      }
-      if (blockDim.x > 1) {
-        sum = BlockReduceAlongDimX<AccT, phi::funcs::AddFunctor>(sdata, sum);
-      }
-
-      // 3. input - max - log_sum, and write to output
-      for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
-        output[data_offset + d * dim_stride] = static_cast<T>(
-            static_cast<AccT>(input[data_offset + d * dim_stride]) -
-            max_value - std::log(sum));
-      }
-    }
-  }
-}
-
-// block.y covers inner_size. Threads along the x axis process dim_size
-// elements, making sure not to exceed the limit of 1024 threads per block.
-// Note that dim_threads, namely blockDim.x, is either 1 or an even number.
-inline dim3 GetBlockSize(int dim_size, int inner_size) {
-  int inner_threads = inner_size;
-  inner_threads = std::min(inner_threads, 1024);
-  int dim_threads = 1;
-
-  while (dim_threads * inner_threads <= 1024 && dim_threads <= dim_size) {
-    dim_threads *= 2;
-  }
-  dim_threads /= 2;
-  return dim3(dim_threads, inner_threads);
-}
-
-// First cover the y axis with as many blocks as possible.
-// Then cover the x axis with as many blocks as possible,
-// making sure not to exceed max_active_blocks.
-inline dim3 GetGridSize(dim3 block, int max_active_blocks, int outer_size,
-                        int dim_size, int inner_size) {
-  int inner_blocks = (inner_size + block.y - 1) / block.y;
-  if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
-
-  int outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
-  if (outer_blocks > outer_size) outer_blocks = outer_size;
-  return dim3(outer_blocks, inner_blocks);
-}
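A worked trace of this launch-shape logic, using the two helpers above on a hypothetical [outer, dim, inner] = [1024, 50, 128] problem (our arithmetic, not code from the patch):

    // GetBlockSize doubles dim_threads while dim_threads * inner_threads
    // stays within 1024 threads, then backs off one step:
    dim3 block = GetBlockSize(/*dim_size=*/50, /*inner_size=*/128);  // (8, 128)
    // 8 * 128 = 1024 threads; block.x != 1, so the launcher below also
    // requests 1024 * sizeof(AccT) bytes of shared memory.
    int max_active_blocks = /*num_sm=*/80 * 2;  // 160, the empirical cap
    dim3 grid = GetGridSize(block, max_active_blocks, /*outer_size=*/1024,
                            /*dim_size=*/50, /*inner_size=*/128);
    // inner_blocks = ceil(128 / 128) = 1; outer_blocks = min(160, 1024) = 160.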
-
-// When choosing grid and block sizes, priority is given to the block size;
-// the grid is then determined from the maximum number of active blocks,
-// which is set as an empirical value.
-template <typename T, typename Kernel>
-void ComputeLaunchConfigure(Kernel k, int outer_size, int dim_size,
-                            int inner_size, dim3 &grid, dim3 &block,
-                            int &shared_mem, int num_sm) {
-  block = GetBlockSize(dim_size, inner_size);
-  int block_threads = block.x * block.y;
-  shared_mem = block.x == 1 ? 0 : block_threads * sizeof(T);
-  int max_active_blocks = num_sm * 2;
-  grid =
-      GetGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
-}
-
-template <typename T, typename MPDType>
-void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data,
-                                                  const T *input_data,
-                                                  int outer_size, int dim_size,
-                                                  int inner_size, int num_sm,
-                                                  gpuStream_t stream) {
-  int shared_mem;
-  dim3 grid;
-  dim3 block;
-
-  ComputeLaunchConfigure<MPDType>(
-      &LogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>, outer_size,
-      dim_size, inner_size, grid, block, shared_mem, num_sm);
-
-  LogSoftmaxForwardCUDAKernelNotLastAxis<
-      T, MPDType><<<grid, block, shared_mem, stream>>>(
-      output_data, input_data, outer_size, dim_size, inner_size);
-}
+using Tensor = framework::Tensor;
 
 template <typename T>
 class LogSoftmaxKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
-  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
-
  public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    const auto *x = context.Input<framework::Tensor>("X");
-    auto *out = context.Output<framework::Tensor>("Out");
-    const auto *input_data = x->data<T>();
-    auto *output_data = out->mutable_data<T>(context.GetPlace());
-
-    const int rank = x->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *x = ctx.Input<Tensor>("X");
+    auto *out = ctx.Output<Tensor>("Out");
+    out->mutable_data<T>(ctx.GetPlace());
 
-    int dim_size = x->dims()[axis];
-    int inner_size = 1;
-    for (int i = axis + 1; i < x->dims().size(); ++i) {
-      inner_size *= x->dims()[i];
-    }
-    int outer_size = SizeToAxis(axis, x->dims());
-    gpuStream_t stream = context.cuda_device_context().stream();
-    int num_sm = context.cuda_device_context().GetSMCount();
-
-    if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
-      LaunchSoftmaxForwardForLastAxis<T, MPDType>(output_data, input_data,
-                                                  dim_size, outer_size, stream);
-    } else {
-      LaunchLogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>(
-          output_data, input_data, outer_size, dim_size, inner_size, num_sm,
-          stream);
-    }
+    int input_axis = ctx.Attr<int>("axis");
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, *x, input_axis, out);
   }
 };
 
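After this rewrite the fluid operator is a thin wrapper: tiling, vectorization, and the cuDNN path all live in the shared phi driver, whose boolean template parameter selects the log-softmax epilogue. A sketch of the call sites, assuming the <T, LogMode> driver signature shown at the end of this patch:

    // Forward: out = log_softmax(x) along input_axis.
    phi::SoftmaxForwardCUDAKernelDriver<T, true>(dev_ctx, *x, input_axis, out);
    // Plain softmax reuses the same driver; LogMode defaults to false.
    phi::SoftmaxForwardCUDAKernelDriver<T, false>(dev_ctx, *x, input_axis, out);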
-// Backward below
-#define LAUNCH_WARP_BACKWARD_COMPUTE(near_greater_power_of_two)              \
-  case near_greater_power_of_two:                                           \
-    ComputeLogSoftmaxBackwardInWarp<                                         \
-        T, AccT, near_greater_power_of_two><<<blocks, threads, 0, stream>>>( \
-        output, grad_output, grad_input, outer_size, dim_size);             \
-    break;
-
-template <typename T, typename AccT, int NearGreaterPowerOfTwo>
-__global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
-                                                const T *grad_output,
-                                                T *grad_input, int batch_size,
-                                                int element_count) {
-  constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo;
-  constexpr int kernel_warp_size =
-      (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
-  constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
-  int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
-
-  int thread_in_warp_idx = threadIdx.x;
-
-  // 1. Read data from global memory into registers.
-  AccT output_register[warp_iter];
-  AccT grad_output_register[warp_iter];
-  int effective_element_count = (batch_id < batch_size) ? element_count : 0;
-  for (int iter = 0; iter < warp_iter; ++iter) {
-    int element_index = thread_in_warp_idx + iter * kernel_warp_size;
-    if (element_index < effective_element_count) {
-      output_register[iter] =
-          static_cast<AccT>(output[batch_id * element_count + element_index]);
-      grad_output_register[iter] = static_cast<AccT>(
-          grad_output[batch_id * element_count + element_index]);
-    } else {
-      output_register[iter] = static_cast<AccT>(0);
-      grad_output_register[iter] = static_cast<AccT>(0);
-    }
-  }
-
-  // 2. For each warp, accumulate all thread registers.
-  AccT sum = grad_output_register[0];
-#pragma unroll
-  for (int iter = 1; iter < warp_iter; ++iter) {
-    sum += grad_output_register[iter];
-  }
-  sum = WarpReduceSum<AccT, kernel_warp_size>(sum);
-
-// 3. Write the result to grad_input.
-#pragma unroll
-  for (int iter = 0; iter < warp_iter; ++iter) {
-    int element_index = thread_in_warp_idx + iter * kernel_warp_size;
-    if (element_index < effective_element_count) {
-      grad_input[batch_id * element_count + element_index] = static_cast<T>(
-          (grad_output_register[iter] - std::exp(output_register[iter]) * sum));
-    }
-  }
-}
-
-template <typename T, typename AccT>
-void LaunchSoftmaxBackwardForLastAxis(T *grad_input, const T *grad_output,
-                                      const T *output, int dim_size,
-                                      int outer_size, gpuStream_t stream) {
-  int threads_per_block = 128;
-  int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size);
-  int kernel_warp_size =
-      (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
-  int warps_per_block = (threads_per_block / kernel_warp_size);
-  int blocks = (outer_size + warps_per_block - 1) / warps_per_block;
-  dim3 threads(kernel_warp_size, warps_per_block, 1);
-
-  switch (near_greater_power_of_two) {
-    LAUNCH_WARP_BACKWARD_COMPUTE(1);     // dim_size: 1
-    LAUNCH_WARP_BACKWARD_COMPUTE(2);     // dim_size: 2
-    LAUNCH_WARP_BACKWARD_COMPUTE(4);     // dim_size: 3~4
-    LAUNCH_WARP_BACKWARD_COMPUTE(8);     // dim_size: 5~8
-    LAUNCH_WARP_BACKWARD_COMPUTE(16);    // dim_size: 9~16
-    LAUNCH_WARP_BACKWARD_COMPUTE(32);    // dim_size: 17~32
-    LAUNCH_WARP_BACKWARD_COMPUTE(64);    // dim_size: 33~64
-    LAUNCH_WARP_BACKWARD_COMPUTE(128);   // dim_size: 65~128
-    LAUNCH_WARP_BACKWARD_COMPUTE(256);   // dim_size: 129~256
-    LAUNCH_WARP_BACKWARD_COMPUTE(512);   // dim_size: 257~512
-    LAUNCH_WARP_BACKWARD_COMPUTE(1024);  // dim_size: 513~1024
-
-    default:
-      break;
-  }
-}
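The backward warp kernel implements the log-softmax gradient identity dx_i = dout_i - exp(out_i) * sum_j dout_j, where out is the saved forward result. A scalar reference in plain C++ (ours, for illustration only):

    #include <cmath>
    #include <vector>

    // Reference log_softmax gradient for one row: one sum over dout,
    // then one fused pass, matching steps 2 and 3 of the warp kernel.
    std::vector<float> LogSoftmaxGradRef(const std::vector<float> &out,
                                         const std::vector<float> &dout) {
      float sum = 0.f;
      for (float g : dout) sum += g;               // step 2: reduce sum
      std::vector<float> dx(out.size());
      for (size_t i = 0; i < out.size(); ++i) {
        dx[i] = dout[i] - std::exp(out[i]) * sum;  // step 3: write result
      }
      return dx;
    }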
-
 template <typename T>
 class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
     : public framework::OpKernel<T> {
-  using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
-
  public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    const auto *out = context.Input<framework::Tensor>("Out");
-    const auto *d_out =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto *d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *out = ctx.Input<Tensor>("Out");
+    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+    dx->mutable_data<T>(ctx.GetPlace());
 
-    const auto *out_data = out->data<T>();
-    const auto *d_out_data = d_out->data<T>();
-    auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
-
-    const int rank = out->dims().size();
-    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
-
-    int dim_size = out->dims()[axis];
-    int inner_size = 1;
-    for (int i = axis + 1; i < out->dims().size(); ++i) {
-      inner_size *= out->dims()[i];
-    }
-    int outer_size = SizeToAxis(axis, out->dims());
-    gpuStream_t stream = context.cuda_device_context().stream();
-
-    if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
-      LaunchSoftmaxBackwardForLastAxis<T, MPDType>(
-          d_x_data, d_out_data, out_data, dim_size, outer_size, stream);
-    } else {
-      LogSoftmaxGradFunctor<platform::CUDADeviceContext, T>()(
-          context.template device_context<platform::CUDADeviceContext>(), out,
-          d_out, d_x, axis);
-    }
+    int input_axis = ctx.Attr<int>("axis");
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    phi::SoftmaxBackwardCUDAKernelDriver<T, true>(dev_ctx, *out, *dout,
+                                                  input_axis, dx);
   }
 };
 
@@ -473,6 +57,17 @@ class LogSoftmaxGradKernel
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
+#ifdef PADDLE_WITH_HIP
+REGISTER_OP_CUDA_KERNEL(
+    log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
+    ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::LogSoftmaxKernel<plat::CUDADeviceContext, plat::bfloat16>);
+REGISTER_OP_CUDA_KERNEL(
+    log_softmax_grad, ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, float>,
+    ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
+#else
 REGISTER_OP_CUDA_KERNEL(
     log_softmax, ops::LogSoftmaxKernel<plat::CUDADeviceContext, float>,
     ops::LogSoftmaxKernel<plat::CUDADeviceContext, double>,
@@ -483,3 +78,4 @@ REGISTER_OP_CUDA_KERNEL(
     ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, double>,
     ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::float16>,
     ops::LogSoftmaxGradKernel<plat::CUDADeviceContext, plat::bfloat16>);
+#endif
diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
index 0352fdf6fa2..2b2dd511896 100644
--- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
+++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h
@@ -351,8 +351,17 @@ __global__ void WarpSoftmaxForward(T* softmax,
     VecT* softmax_v =
         reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
     VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
-    kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
-        &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    if (LogMode) {
+      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnaryLogFunctor<AccT>>(
+          &srcdata[i][0][0], &srcdata[i][0][0], UnaryLogFunctor<AccT>());
+      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+          &out_tmp[i][0][0],
+          &srcdata[i][0][0],
+          UnarySubFunctor<AccT>(std::log(sum[i])));
+    } else {
+      kps::ElementwiseUnary<AccT, T, kVItem, 1, 1, UnaryDivFunctor<AccT>>(
+          &out_tmp[i][0][0], &srcdata[i][0][0], UnaryDivFunctor<AccT>(sum[i]));
+    }
     kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
         &softmax_v[0], &reg_v[0], idx_max_v[i], 0, kWarpSize, 1);
   }
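In log mode the epilogue trades the divide for a log and a subtract. At this point srcdata already holds exp(x - max), computed for the sum, so (our notation, summarizing the branch above):

    // e_j = exp(x_j - max)                          (held in srcdata)
    // softmax:      out_j = e_j / sum               (UnaryDivFunctor)
    // log-softmax:  out_j = log(e_j) - log(sum)     (UnaryLogFunctor, then
    //                     = (x_j - max) - log(sum)   UnarySubFunctor(log(sum)))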
@@ -434,15 +443,25 @@ __global__ void WarpSoftmaxBackward(T* dst,
   AccT sum_tmp[kBatchSize][kLoopsV][kVSize];
   AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[0][0][0]);
   AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[0][0][0]);
-  kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
-      &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
-  kps::Reduce<AccT,
-              kVItem,
-              kBatchSize,
-              1,
-              kps::AddFunctor<AccT>,
-              kps::details::ReduceMode::kLocalMode>(
-      &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
+  if (LogMode) {
+    kps::Reduce<AccT,
+                kVItem,
+                kBatchSize,
+                1,
+                kps::AddFunctor<AccT>,
+                kps::details::ReduceMode::kLocalMode>(
+        &sum[0], &grad_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
+  } else {
+    kps::ElementwiseBinary<AccT, AccT, kStep, 1, 1, kps::MulFunctor<AccT>>(
+        &sum_tmp[0][0][0], &gradptr[0], &srcptr[0], kps::MulFunctor<AccT>());
+    kps::Reduce<AccT,
+                kVItem,
+                kBatchSize,
+                1,
+                kps::AddFunctor<AccT>,
+                kps::details::ReduceMode::kLocalMode>(
+        &sum[0], &sum_tmp[0][0][0], kps::AddFunctor<AccT>(), true);
+  }
   WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
 
   // write result to global memory
@@ -453,10 +472,23 @@ __global__ void WarpSoftmaxBackward(T* dst,
     if (i >= local_batches) break;
     AccT* gradptr = reinterpret_cast<AccT*>(&grad_tmp[i][0][0]);
     AccT* srcptr = reinterpret_cast<AccT*>(&src_tmp[i][0][0]);
-    kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
-        &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
-    kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
-        &out_tmp[i][0][0], &srcptr[0], &out[i][0][0], kps::MulFunctor<AccT>());
+    if (LogMode) {
+      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpMulFunctor<AccT>>(
+          &out[i][0][0], &srcptr[0], ExpMulFunctor<AccT>(sum[i]));
+      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::SubFunctor<AccT>>(
+          &out_tmp[i][0][0],
+          &gradptr[0],
+          &out[i][0][0],
+          kps::SubFunctor<AccT>());
+    } else {
+      kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, UnarySubFunctor<AccT>>(
+          &out[i][0][0], &gradptr[0], UnarySubFunctor<AccT>(sum[i]));
+      kps::ElementwiseBinary<AccT, T, kVItem, 1, 1, kps::MulFunctor<AccT>>(
+          &out_tmp[i][0][0],
+          &srcptr[0],
+          &out[i][0][0],
+          kps::MulFunctor<AccT>());
+    }
     VecT* dst_v = reinterpret_cast<VecT*>(&dst[(first_batch + i) * stride]);
     VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
     kps::WriteData<VecT, VecT, kLoopsV, 1, 1, true>(
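Each kps:: primitive above is, conceptually, an unrolled per-thread loop over a small register-resident array, with the integer template arguments fixing the element count. A hypothetical stand-in, not the real kps API, just to fix intuition:

    // Conceptual stand-in for kps::ElementwiseUnary (name and signature are
    // ours): apply f to N register-resident values owned by this thread.
    template <typename InT, typename OutT, int N, typename Functor>
    __device__ __forceinline__ void ElementwiseUnarySketch(OutT* out,
                                                           const InT* in,
                                                           Functor f) {
    #pragma unroll
      for (int i = 0; i < N; ++i) {
        out[i] = static_cast<OutT>(f(in[i]));
      }
    }

With such primitives the LogMode branch stays entirely in registers: one add-reduce over dout, then dx = dout - exp(out) * sum via ExpMulFunctor followed by SubFunctor.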
@@ -639,7 +671,8 @@ __global__ void NormalSoftmaxForward(
 template <typename T,
           typename AccT,
-          template <typename, typename> class Functor>
+          template <typename, typename> class Functor,
+          bool LogMode>
 __global__ void NormalSoftmaxBackward(T* input_grad,
                                       const T* output_grad,
                                       const T* output,
@@ -656,10 +689,17 @@ __global__ void NormalSoftmaxBackward(T* input_grad,
   // 1. reduce sum
   AccT sum = 0;
-  for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
-    int data_offset = grad_offset + mid_id * mid_stride;
-    sum += static_cast<AccT>(output_grad[data_offset]) *
-           static_cast<AccT>(output[data_offset]);
+  if (LogMode) {
+    for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+      int data_offset = grad_offset + mid_id * mid_stride;
+      sum += static_cast<AccT>(output_grad[data_offset]);
+    }
+  } else {
+    for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
+      int data_offset = grad_offset + mid_id * mid_stride;
+      sum += static_cast<AccT>(output_grad[data_offset]) *
+             static_cast<AccT>(output[data_offset]);
+    }
   }
   if (blockDim.y > 1) {
     kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
@@ -715,10 +755,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
   dim3 grid, block;
   GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
   if (LogMode) {
-    NormalSoftmaxBackward<
-        T,
-        AccT,
-        LogSoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+    NormalSoftmaxBackward<T,
+                          AccT,
+                          LogSoftmaxBackwardFunctor,
+                          true><<<grid, block, 0, dev_ctx.stream()>>>(
         input_grad_data,
         output_grad_data,
         output_data,
@@ -726,10 +766,10 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
         mid_dim,
         low_dim);
   } else {
-    NormalSoftmaxBackward<
-        T,
-        AccT,
-        SoftmaxBackwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
+    NormalSoftmaxBackward<T,
+                          AccT,
+                          SoftmaxBackwardFunctor,
+                          false><<<grid, block, 0, dev_ctx.stream()>>>(
         input_grad_data,
         output_grad_data,
         output_data,
@@ -864,6 +904,32 @@ static bool CanUseCudnnSoftmax(const GPUContext& dev_ctx) {
   return false;
 }
 
+#if CUDNN_VERSION < 8100
+template <>
+inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
+    const GPUContext& dev_ctx,
+    const DenseTensor& x,
+    const int axis,
+    const bool log_mode,
+    DenseTensor* out) {
+  PADDLE_THROW(errors::Unavailable(
+      "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
+      "8100."));
+}
+template <>
+inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
+    const GPUContext& dev_ctx,
+    const DenseTensor& out,
+    const DenseTensor& dout,
+    const int axis,
+    const bool log_mode,
+    DenseTensor* dx) {
+  PADDLE_THROW(errors::Unavailable(
+      "This kernel is not supported when the dtype is bf16 and CUDNN_VERSION < "
+      "8100."));
+}
+#endif
+
 template <typename T, bool LogMode = false>
 void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                     const DenseTensor& x,
-- 
GitLab
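The CUDNN_VERSION guard above works by shadowing the generic template with a bf16 specialization that fails loudly at dispatch time. The same pattern in miniature (a self-contained sketch with our own names; OLD_LIBRARY_VERSION plays the role of the CUDNN_VERSION test):

    #include <cstdio>
    #include <cstdlib>

    struct bfloat16 {};  // stand-in for phi::dtype::bfloat16

    template <typename T>
    void Run() { std::puts("supported dtype"); }

    #ifdef OLD_LIBRARY_VERSION  // cf. CUDNN_VERSION < 8100
    template <>
    void Run<bfloat16>() {  // shadows the generic path for bf16 only
      std::puts("bf16 is not supported at this library version");
      std::abort();         // cf. PADDLE_THROW(errors::Unavailable(...))
    }
    #endif

    int main() {
      Run<float>();     // always fine
      Run<bfloat16>();  // aborts only when OLD_LIBRARY_VERSION is defined
    }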