diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu index 12c607adb44f4e9590bd5a50305c9d6fd5b3d1d7..7c47ad90502ebd1f1aa0524110c501f38034b936 100644 --- a/paddle/fluid/operators/log_softmax_op.cu +++ b/paddle/fluid/operators/log_softmax_op.cu @@ -15,6 +15,7 @@ #include #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/log_softmax_op.h" +#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/platform/cuda_device_function.h" namespace paddle { @@ -142,6 +143,170 @@ void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size, } } +// Returns the final item after reduce operation along block.x. +// Firstly, get shared memory(smem) offset, find the starting position for every +// y. +// Secondly, initialise every smem position with value 'val' of thread itself. +// Thirdly, apply standard reduction along x direction as below: +// +// -> x direction +// [o o o o o o o o] time 0 +// | |/ / +// | /| / +// | / | / +// |/ |/ +// [o o o o x x x x] time 1 +// | |/ / +// |/|/ +// [o o x x x x x x] time 2 +// |/ +// [o x x x x x x x] time 3 +// +// Finally, return the first item. +// Imaging multiple reductions executed in paralell along y axis, +// Note that when blockDim.x is not 1, it's a EVEN number in all cases, +// and the size of shared memory is even as well. +template class Functor> +__forceinline__ __device__ T BlockReduceAlongDimX(T *shared, T val) { + Functor func; + // This reduction is not Block-wise reduction, only reduce along block.x. + // therefore the shared mem has offsets for different block.y. + shared += threadIdx.y * blockDim.x; + shared[threadIdx.x] = val; + int offset = blockDim.x / 2; + + while (offset > 0) { + __syncthreads(); + if (threadIdx.x < offset) { + shared[threadIdx.x] = + func(shared[threadIdx.x], shared[threadIdx.x + offset]); + } + offset /= 2; + } + __syncthreads(); + return shared[0]; +} + +template +__global__ void LogSoftmaxForwardCUDAKernelNotLastAxis( + T *output, const T *input, int outer_size, int dim_size, int inner_size) { + extern __shared__ unsigned char smem[]; + auto sdata = reinterpret_cast(smem); + + const int outer_stride = inner_size * dim_size; + const int dim_stride = inner_size; + + for (int x_id = blockIdx.x; x_id < outer_size; x_id += gridDim.x) { + for (int y_id = blockIdx.y * blockDim.y + threadIdx.y; y_id < inner_size; + y_id += blockDim.y * gridDim.y) { + const int data_offset = x_id * outer_stride + y_id; + // When blockDim.x==1, no block.x-reduction opetaions are needed. + // And threadIdx.x is 0 all the time, so the for-loops below are literally + // loops (No parallel executions). Loop all elements along axis and + // calculate the Max, Sum and (input[id]-Max-log(Sum)) to get the final + // log_softmax values along that axis. + // 1. reduce max + AccT max_value = -std::numeric_limits::infinity(); + // For one thread, iterate all items it responsable for, and get + // max_value. + // If there are N threads, N max_value will be returned. + for (int d = threadIdx.x; d < dim_size; d += blockDim.x) { + const AccT value = + static_cast(input[data_offset + d * dim_stride]); + max_value = math::MaxFunctor()(max_value, value); + } + // If there are more than 1 threads along block x, reduce all max_values + // and get the global max_value, which is the max value along "axis". + // If there is only one thread along block x, no need to reduce, as the + // 'max_value' is the global max_value. + if (blockDim.x > 1) { + max_value = + BlockReduceAlongDimX(sdata, max_value); + } + + // 2. reduce sum + AccT sum = 0; + // Below is the same execution as '1. reduce max' + for (int d = threadIdx.x; d < dim_size; d += blockDim.x) { + sum += std::exp(static_cast(input[data_offset + d * dim_stride]) - + max_value); + } + if (blockDim.x > 1) { + sum = BlockReduceAlongDimX(sdata, sum); + } + + // 3. input-max-log_sum and write to output + for (int d = threadIdx.x; d < dim_size; d += blockDim.x) { + output[data_offset + d * dim_stride] = static_cast( + static_cast(input[data_offset + d * dim_stride]) - max_value - + std::log(sum)); + } + } + } +} + +// block.y covers inner_size. Threads along the x axis process dim_size +// elements, and make sure not to exceed the 1024 threads per block. +// Note that dim_threads namely blockDim.x is either 1 or a even number. +inline dim3 GetBlockSize(int dim_size, int inner_size) { + int inner_threads = inner_size; + inner_threads = std::min(inner_threads, 1024); + int dim_threads = 1; + + while (dim_threads * inner_threads <= 1024 && dim_threads <= dim_size) { + dim_threads *= 2; + } + dim_threads /= 2; + return dim3(dim_threads, inner_threads); +} + +// First cover the y axis as many blocks as possible. +// Then cover the x axis as many blocks as possible, +// and make sure not to exceed the max_active_blocks. +inline dim3 GetGridSize(dim3 block, int max_active_blocks, int outer_size, + int dim_size, int inner_size) { + int inner_blocks = (inner_size + block.y - 1) / block.y; + if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks; + + int outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks; + if (outer_blocks > outer_size) outer_blocks = outer_size; + return dim3(outer_blocks, inner_blocks); +} + +// When designing grid size and block size, priority is given to block size, +// and grid will be determined according to the maximum number of active blocks, +// which is set by as a experience value. +template +void ComputeLaunchConfigure(Kernel k, int outer_size, int dim_size, + int inner_size, dim3 &grid, dim3 &block, + int &shared_mem, int num_sm) { + block = GetBlockSize(dim_size, inner_size); + int block_threads = block.x * block.y; + shared_mem = block.x == 1 ? 0 : block_threads * sizeof(T); + int max_active_blocks = num_sm * 2; + grid = + GetGridSize(block, max_active_blocks, outer_size, dim_size, inner_size); +} + +template +void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data, + const T *input_data, + int outer_size, int dim_size, + int inner_size, int num_sm, + gpuStream_t stream) { + int shared_mem; + dim3 grid; + dim3 block; + + ComputeLaunchConfigure( + &LogSoftmaxForwardCUDAKernelNotLastAxis, outer_size, dim_size, + inner_size, grid, block, shared_mem, num_sm); + + LogSoftmaxForwardCUDAKernelNotLastAxis< + T, MPDType><<>>( + output_data, input_data, outer_size, dim_size, inner_size); +} + template class LogSoftmaxKernel : public framework::OpKernel { @@ -164,14 +329,15 @@ class LogSoftmaxKernel } int outer_size = SizeToAxis(axis, x->dims()); gpuStream_t stream = context.cuda_device_context().stream(); + int num_sm = context.cuda_device_context().GetSMCount(); if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) { LaunchSoftmaxForwardForLastAxis(output_data, input_data, dim_size, outer_size, stream); } else { - LogSoftmaxFunctor()( - context.template device_context(), x, - out, axis); + LaunchLogSoftmaxForwardCUDAKernelNotLastAxis( + output_data, input_data, outer_size, dim_size, inner_size, num_sm, + stream); } } }; @@ -195,7 +361,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output, constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size; int batch_id = blockDim.y * blockIdx.x + threadIdx.y; - int thread_in_warp_idx = threadIdx.x % kernel_warp_size; + int thread_in_warp_idx = threadIdx.x; // 1.read data from global memory to registers AccT output_register[warp_iter]; @@ -209,8 +375,8 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output, grad_output_register[iter] = static_cast( grad_output[batch_id * element_count + element_index]); } else { - output_register[iter] = AccT(0); - grad_output_register[iter] = AccT(0); + output_register[iter] = static_cast(0); + grad_output_register[iter] = static_cast(0); } } @@ -271,13 +437,13 @@ class LogSoftmaxGradKernel public: void Compute(const framework::ExecutionContext &context) const override { const auto *out = context.Input("Out"); - const auto *g_out = + const auto *d_out = context.Input(framework::GradVarName("Out")); - auto *g_x = context.Output(framework::GradVarName("X")); + auto *d_x = context.Output(framework::GradVarName("X")); const auto *out_data = out->data(); - const auto *g_out_data = g_out->data(); - auto *g_x_data = g_x->mutable_data(context.GetPlace()); + const auto *d_out_data = d_out->data(); + auto *d_x_data = d_x->mutable_data(context.GetPlace()); const int rank = out->dims().size(); const int axis = CanonicalAxis(context.Attr("axis"), rank); @@ -292,11 +458,11 @@ class LogSoftmaxGradKernel if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) { LaunchSoftmaxBackwardForLastAxis( - g_x_data, g_out_data, out_data, dim_size, outer_size, stream); + d_x_data, d_out_data, out_data, dim_size, outer_size, stream); } else { LogSoftmaxGradFunctor()( context.template device_context(), out, - g_out, g_x, axis); + d_out, d_x, axis); } } }; diff --git a/paddle/fluid/operators/math/functors.h b/paddle/fluid/operators/math/functors.h index 2eb6d0093538939957b22d06ed85fbeb0fd01a55..054018b10e87e421c45846abf550f0f7a552f6a3 100644 --- a/paddle/fluid/operators/math/functors.h +++ b/paddle/fluid/operators/math/functors.h @@ -41,6 +41,11 @@ struct AddFunctor { inline HOSTDEVICE T operator()(T x, T y) { return x + y; } }; +template +struct MaxFunctor { + inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? b : a; } +}; + template struct AddGradFunctor { inline HOSTDEVICE T Dx(T x, T y) { return static_cast(1.); }