Unverified commit 69ffb386 authored by Lijunhui, committed by GitHub

Optimize the forward of log_softmax for the case when axis is not the last dimension. (#32396)

Parent 389f8c5e
......@@ -15,6 +15,7 @@
#include <limits>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
......@@ -142,6 +143,170 @@ void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size,
}
}
// Returns the final item after a reduce operation along block.x.
// First, compute the shared memory (smem) offset to find the starting
// position for each y.
// Second, initialize each smem slot with the calling thread's own value 'val'.
// Third, apply a standard tree reduction along the x direction, as below:
//
// -> x direction
// [o o o o o o o o] time 0
// | |/ /
// | /| /
// | / | /
// |/ |/
// [o o o o x x x x] time 1
// | |/ /
// |/|/
// [o o x x x x x x] time 2
// |/
// [o x x x x x x x] time 3
//
// Finally, return the first item.
// Imagine multiple such reductions executing in parallel along the y axis.
// Note that when blockDim.x is not 1, it is an EVEN number in all cases,
// and the size of the shared memory is even as well.
template <typename T, template <typename> class Functor>
__forceinline__ __device__ T BlockReduceAlongDimX(T *shared, T val) {
Functor<T> func;
// This is not a block-wide reduction; it reduces only along block.x,
// so the shared memory holds a separate offset (row) for each block.y.
shared += threadIdx.y * blockDim.x;
shared[threadIdx.x] = val;
int offset = blockDim.x / 2;
while (offset > 0) {
__syncthreads();
if (threadIdx.x < offset) {
shared[threadIdx.x] =
func(shared[threadIdx.x], shared[threadIdx.x + offset]);
}
offset /= 2;
}
__syncthreads();
return shared[0];
}
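As a minimal usage sketch (a hypothetical toy kernel, not part of this diff), each block.y lane below sums its threads' x indices, so with blockDim.x == 8 every lane reduces to 0 + 1 + ... + 7 = 28:
// Hypothetical illustration of BlockReduceAlongDimX; assumes an even
// blockDim.x and blockDim.x * blockDim.y * sizeof(T) bytes of shared memory.
template <typename T>
__global__ void ToyRowSum(T *out) {
  extern __shared__ unsigned char smem[];
  auto shared = reinterpret_cast<T *>(smem);
  T row_sum = BlockReduceAlongDimX<T, math::AddFunctor>(
      shared, static_cast<T>(threadIdx.x));
  // After the reduction, every thread in a row sees its row's total.
  if (threadIdx.x == 0) out[threadIdx.y] = row_sum;
}
// Launch: ToyRowSum<float><<<1, dim3(8, 4), 8 * 4 * sizeof(float)>>>(d_out);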
template <typename T, typename AccT>
__global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
T *output, const T *input, int outer_size, int dim_size, int inner_size) {
extern __shared__ unsigned char smem[];
auto sdata = reinterpret_cast<AccT *>(smem);
const int outer_stride = inner_size * dim_size;
const int dim_stride = inner_size;
for (int x_id = blockIdx.x; x_id < outer_size; x_id += gridDim.x) {
for (int y_id = blockIdx.y * blockDim.y + threadIdx.y; y_id < inner_size;
y_id += blockDim.y * gridDim.y) {
const int data_offset = x_id * outer_stride + y_id;
// When blockDim.x == 1, no block.x-reduction operations are needed, and
// threadIdx.x is always 0, so the for-loops below run serially (no
// parallel execution). Loop over all elements along the axis and
// compute the Max, the Sum, and (input[id] - Max - log(Sum)) to get the
// final log_softmax values along that axis.
// 1. reduce max
AccT max_value = -std::numeric_limits<AccT>::infinity();
// Each thread iterates over the items it is responsible for and keeps a
// local max_value.
// With N threads, N local max_values are produced.
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
const AccT value =
static_cast<AccT>(input[data_offset + d * dim_stride]);
max_value = math::MaxFunctor<AccT>()(max_value, value);
}
// If there is more than one thread along block.x, reduce the local
// max_values to get the global max value along "axis".
// If there is only one thread along block.x, no reduction is needed:
// 'max_value' is already the global max value.
if (blockDim.x > 1) {
max_value =
BlockReduceAlongDimX<AccT, math::MaxFunctor>(sdata, max_value);
}
// 2. reduce sum
AccT sum = 0;
// The same execution pattern as in '1. reduce max'.
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
sum += std::exp(static_cast<AccT>(input[data_offset + d * dim_stride]) -
max_value);
}
if (blockDim.x > 1) {
sum = BlockReduceAlongDimX<AccT, math::AddFunctor>(sdata, sum);
}
// 3. compute input - max - log(sum) and write to output
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
output[data_offset + d * dim_stride] = static_cast<T>(
static_cast<AccT>(input[data_offset + d * dim_stride]) - max_value -
std::log(sum));
}
}
}
}
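For reference, a serial CPU sketch of the same three passes (a hypothetical helper, for illustration only; strides match the kernel above, with elements along the axis separated by inner_size):
#include <cmath>

template <typename T>
void LogSoftmaxMiddleAxisRef(T *out, const T *in, int outer_size, int dim_size,
                             int inner_size) {
  for (int o = 0; o < outer_size; ++o) {
    for (int i = 0; i < inner_size; ++i) {
      const T *src = in + o * dim_size * inner_size + i;
      T *dst = out + o * dim_size * inner_size + i;
      // 1. max along the axis
      T max_value = src[0];
      for (int d = 1; d < dim_size; ++d)
        max_value = src[d * inner_size] > max_value ? src[d * inner_size]
                                                    : max_value;
      // 2. sum of exp(x - max)
      T sum = 0;
      for (int d = 0; d < dim_size; ++d)
        sum += std::exp(src[d * inner_size] - max_value);
      // 3. x - max - log(sum)
      for (int d = 0; d < dim_size; ++d)
        dst[d * inner_size] = src[d * inner_size] - max_value - std::log(sum);
    }
  }
}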
// block.y covers inner_size. Threads along the x axis process dim_size
// elements while keeping the block within the limit of 1024 threads.
// Note that dim_threads, namely blockDim.x, is either 1 or an even number.
inline dim3 GetBlockSize(int dim_size, int inner_size) {
int inner_threads = inner_size;
inner_threads = std::min(inner_threads, 1024);
int dim_threads = 1;
while (dim_threads * inner_threads <= 1024 && dim_threads <= dim_size) {
dim_threads *= 2;
}
dim_threads /= 2;
return dim3(dim_threads, inner_threads);
}
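// Worked example for GetBlockSize: with dim_size = 64 and inner_size = 10,
// inner_threads = 10 and dim_threads doubles 1 -> 2 -> ... -> 128 (the last
// accepted value is 64, since 64 * 10 = 640 <= 1024 and 64 <= dim_size); the
// loop exits at 128 * 10 = 1280 > 1024, and the final halving gives a
// (64, 10) block of 640 threads.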
// First cover the y axis with as many blocks as possible,
// then cover the x axis with as many blocks as possible,
// without exceeding max_active_blocks.
inline dim3 GetGridSize(dim3 block, int max_active_blocks, int outer_size,
int dim_size, int inner_size) {
int inner_blocks = (inner_size + block.y - 1) / block.y;
if (inner_blocks > max_active_blocks) inner_blocks = max_active_blocks;
int outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks;
if (outer_blocks > outer_size) outer_blocks = outer_size;
return dim3(outer_blocks, inner_blocks);
}
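// Worked example for GetGridSize: continuing with block = (64, 10),
// inner_size = 10, outer_size = 1000, and max_active_blocks = 80 (e.g. an
// assumed 40 SMs * 2): inner_blocks = (10 + 9) / 10 = 1 and
// outer_blocks = min(80 / 1, 1000) = 80, i.e. an (80, 1) grid.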
// When designing the grid and block sizes, priority is given to the block
// size; the grid is then derived from the maximum number of active blocks,
// which is set as an empirical value.
template <typename T, typename Kernel>
void ComputeLaunchConfigure(Kernel k, int outer_size, int dim_size,
int inner_size, dim3 &grid, dim3 &block,
int &shared_mem, int num_sm) {
block = GetBlockSize(dim_size, inner_size);
int block_threads = block.x * block.y;
shared_mem = block.x == 1 ? 0 : block_threads * sizeof(T);
int max_active_blocks = num_sm * 2;
grid =
GetGridSize(block, max_active_blocks, outer_size, dim_size, inner_size);
}
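// For the (64, 10) block above with T = float, shared_mem is
// 640 * sizeof(float) = 2560 bytes; when block.x == 1, no x-reduction
// happens, so no shared memory is requested.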
template <typename T, typename MPDType>
void LaunchLogSoftmaxForwardCUDAKernelNotLastAxis(T *output_data,
const T *input_data,
int outer_size, int dim_size,
int inner_size, int num_sm,
gpuStream_t stream) {
int shared_mem;
dim3 grid;
dim3 block;
ComputeLaunchConfigure<MPDType>(
&LogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>, outer_size, dim_size,
inner_size, grid, block, shared_mem, num_sm);
LogSoftmaxForwardCUDAKernelNotLastAxis<
T, MPDType><<<grid, block, shared_mem, stream>>>(
output_data, input_data, outer_size, dim_size, inner_size);
}
template <typename T>
class LogSoftmaxKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
......@@ -164,14 +329,15 @@ class LogSoftmaxKernel<platform::CUDADeviceContext, T>
}
int outer_size = SizeToAxis(axis, x->dims());
gpuStream_t stream = context.cuda_device_context().stream();
+ int num_sm = context.cuda_device_context().GetSMCount();
if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
LaunchSoftmaxForwardForLastAxis<T, MPDType>(output_data, input_data,
dim_size, outer_size, stream);
} else {
- LogSoftmaxFunctor<platform::CUDADeviceContext, T>()(
-     context.template device_context<platform::CUDADeviceContext>(), x,
-     out, axis);
+ LaunchLogSoftmaxForwardCUDAKernelNotLastAxis<T, MPDType>(
+     output_data, input_data, outer_size, dim_size, inner_size, num_sm,
+     stream);
}
}
};
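// Shape example: for an input of shape [N, C, H, W], axis = 1 gives
// outer_size = N, dim_size = C, inner_size = H * W, which takes the
// not-last-axis path; axis = 3 gives inner_size = 1 and, for dim_size <= 1024
// (and dim_size * sizeof(T) <= 4096), the warp-level last-axis kernel.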
......@@ -195,7 +361,7 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
- int thread_in_warp_idx = threadIdx.x % kernel_warp_size;
+ int thread_in_warp_idx = threadIdx.x;
// 1. read data from global memory to registers
AccT output_register[warp_iter];
......@@ -209,8 +375,8 @@ __global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
grad_output_register[iter] = static_cast<AccT>(
grad_output[batch_id * element_count + element_index]);
} else {
- output_register[iter] = AccT(0);
- grad_output_register[iter] = AccT(0);
+ output_register[iter] = static_cast<AccT>(0);
+ grad_output_register[iter] = static_cast<AccT>(0);
}
}
......@@ -271,13 +437,13 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto *out = context.Input<framework::Tensor>("Out");
- const auto *g_out =
+ const auto *d_out =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
- auto *g_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
+ auto *d_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
const auto *out_data = out->data<T>();
- const auto *g_out_data = g_out->data<T>();
- auto *g_x_data = g_x->mutable_data<T>(context.GetPlace());
+ const auto *d_out_data = d_out->data<T>();
+ auto *d_x_data = d_x->mutable_data<T>(context.GetPlace());
const int rank = out->dims().size();
const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
......@@ -292,11 +458,11 @@ class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
LaunchSoftmaxBackwardForLastAxis<T, MPDType>(
- g_x_data, g_out_data, out_data, dim_size, outer_size, stream);
+ d_x_data, d_out_data, out_data, dim_size, outer_size, stream);
} else {
LogSoftmaxGradFunctor<platform::CUDADeviceContext, T>()(
context.template device_context<platform::CUDADeviceContext>(), out,
- g_out, g_x, axis);
+ d_out, d_x, axis);
}
}
};
......
......@@ -41,6 +41,11 @@ struct AddFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};
template <typename T>
struct MaxFunctor {
inline HOSTDEVICE T operator()(T a, T b) const { return a < b ? b : a; }
};
template <typename T>
struct AddGradFunctor {
inline HOSTDEVICE T Dx(T x, T y) { return static_cast<T>(1.); }
......