未验证 提交 8e1b0204 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize performance of softmax_fwd when axis!=-1 (#38602)

* Optimize performence of softmax_fwd when axis!=-1

* use functor

* support hip

* fix functor
上级 b292dfb8
......@@ -186,6 +186,58 @@ struct UnaryDivFunctor {
Tx n_inv;
};
template <typename Tx, typename Ty = Tx>
struct SoftmaxForwardFunctor {
HOSTDEVICE inline SoftmaxForwardFunctor(Tx max, Tx sum)
: max(max), sum(sum) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(std::exp(x - max) / sum);
}
private:
Tx max;
Tx sum;
};
template <typename Tx, typename Ty = Tx>
struct SoftmaxBackwardFunctor {
HOSTDEVICE inline SoftmaxBackwardFunctor(Tx sum) : sum(sum) {}
HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
return static_cast<Ty>(out * (grad_out - sum));
}
private:
Tx sum;
};
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxForwardFunctor {
HOSTDEVICE inline LogSoftmaxForwardFunctor(Tx max, Tx sum)
: max(max), log_sum(std::log(sum)) {}
HOSTDEVICE inline Ty operator()(const Tx& x) const {
return static_cast<Ty>(x - max - log_sum);
}
private:
Tx max;
Tx log_sum;
};
template <typename Tx, typename Ty = Tx>
struct LogSoftmaxBackwardFunctor {
HOSTDEVICE inline LogSoftmaxBackwardFunctor(Tx sum) : sum(sum) {}
HOSTDEVICE inline Ty operator()(const Tx& grad_out, const Tx& out) const {
return static_cast<Ty>(grad_out - std::exp(out) * sum);
}
private:
Tx sum;
};
/*
Core function of computing softmax forward for axis=-1.
The computation includes
......@@ -255,7 +307,8 @@ __global__ void WarpSoftmaxForward(T* softmax, const T* src,
ReduceMaxFunctor<AccT>(), true);
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max);
// compute sum
// compute sum
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, 1, ExpSubFunctor<AccT>>(
&srcdata[i][0][0], &srcdata[i][0][0], ExpSubFunctor<AccT>(max[i]));
......@@ -443,6 +496,113 @@ void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE
/**
* <NormalSoftmaxKernel>
* Better performence when axis != -1
*/
static void GetGridDim(int high_dim, int mid_dim, int low_dim,
const dim3& block, dim3* grid) {
int device_id = paddle::platform::GetCurrentDeviceId();
int max_mp = paddle::platform::GetGPUMultiProcessors(device_id);
int max_threads_per_mp =
paddle::platform::GetGPUMaxThreadsPerMultiProcessor(device_id);
int max_threads = max_threads_per_mp * max_mp;
int num_threads = block.x * block.y;
int max_num_blocks = max_threads / num_threads;
int grid_x = (low_dim + block.x - 1) / block.x;
grid_x = std::min(grid_x, max_num_blocks);
int grid_y = (max_num_blocks + grid_x - 1) / grid_x;
grid_y = std::min(grid_y, high_dim);
grid->x = grid_x;
grid->y = grid_y;
}
static void GetBlockDim(int mid_dim, int low_dim, dim3* block) {
#ifdef __HIPCC__
constexpr int max_num_threads = 256;
#else
constexpr int max_num_threads = 1024;
#endif
int block_x = 1 << log2_ceil(low_dim);
int block_y = 1 << log2_ceil(mid_dim);
block->x = std::min(block_x, 32);
block->y = std::min(block_y, static_cast<int>(max_num_threads / block->x));
block->x = std::min(block_x, static_cast<int>(max_num_threads / block->y));
}
static void GetLaunchConfig(int high_dim, int mid_dim, int low_dim, dim3* grid,
dim3* block) {
GetBlockDim(mid_dim, low_dim, block);
GetGridDim(high_dim, mid_dim, low_dim, *block, grid);
}
template <typename T, typename AccT,
template <typename, typename> class Functor>
__global__ void NormalSoftmaxForward(T* output, const T* input, int high_dim,
int mid_dim, int low_dim) {
using kMode = kps::details::ReduceMode;
const int high_stride = mid_dim * low_dim;
const int mid_stride = low_dim;
for (int high_id = blockIdx.y; high_id < high_dim; high_id += gridDim.y) {
for (int low_id = blockIdx.x * blockDim.x + threadIdx.x; low_id < low_dim;
low_id += blockDim.x * gridDim.x) {
const int input_offset = high_id * high_stride + low_id;
// 1. reduce max
AccT max_value = -std::numeric_limits<AccT>::infinity();
AccT value = -std::numeric_limits<AccT>::infinity();
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
max_value = kps::MaxFunctor<AccT>()(max_value, value);
}
if (blockDim.y > 1) {
kps::Reduce<AccT, 1, 1, 1, kps::MaxFunctor<AccT>, kMode::kGlobalMode>(
&max_value, &max_value, kps::MaxFunctor<AccT>(), false);
}
// 2. reduce sum
AccT sum = 0;
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
value = static_cast<AccT>(input[input_offset + mid_id * mid_stride]);
sum += std::exp(value - max_value);
}
if (blockDim.y > 1) {
kps::Reduce<AccT, 1, 1, 1, kps::AddFunctor<AccT>, kMode::kGlobalMode>(
&sum, &sum, kps::AddFunctor<AccT>(), false);
}
// 3. (log)softmax
Functor<AccT, T> functor(max_value, sum);
for (int mid_id = threadIdx.y; mid_id < mid_dim; mid_id += blockDim.y) {
int data_offset = input_offset + mid_id * mid_stride;
output[data_offset] = functor(static_cast<AccT>(input[data_offset]));
}
}
}
}
template <typename T, bool LogMode = false>
void LaunchNormalSoftmaxForward(const platform::CUDADeviceContext& dev_ctx,
T* output_data, const T* input_data,
int high_dim, int mid_dim, int low_dim) {
using AccT = typename details::MPTypeTrait<T>::Type;
dim3 grid, block;
GetLaunchConfig(high_dim, mid_dim, low_dim, &grid, &block);
if (LogMode) {
NormalSoftmaxForward<
T, AccT,
LogSoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
output_data, input_data, high_dim, mid_dim, low_dim);
} else {
NormalSoftmaxForward<
T, AccT, SoftmaxForwardFunctor><<<grid, block, 0, dev_ctx.stream()>>>(
output_data, input_data, high_dim, mid_dim, low_dim);
}
}
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
const Tensor& x, const int input_axis,
......@@ -490,6 +650,9 @@ void SoftmaxForwardCUDAKernelDriver(const platform::CUDADeviceContext& dev_ctx,
out_data, x.data<T>(), N, dim,
dim, kDimLog2);
}
} else if (D > 1) {
LaunchNormalSoftmaxForward<T, LogMode>(dev_ctx, out_data, x.data<T>(), N,
dim, D);
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册