Unverified commit 9012787f authored by carryyu, committed by GitHub

Optimize softmax's performance when dim_size >= 100000. (#46535)

Parent 7057093e
@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
@@ -26,6 +27,32 @@ limitations under the License. */
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#define MATRIX_SOFTMAX_ALIGN_BYTES 16
#define MATRIX_SOFTMAX_THREAHOLD 100000
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_VEC_SIZE_BASE(vec_size, ...) \
case (vec_size): { \
constexpr auto VecSize = (vec_size); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(512, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
#define FIXED_VEC_SIZE(...) \
FIXED_VEC_SIZE_BASE(8, ##__VA_ARGS__); \
FIXED_VEC_SIZE_BASE(4, ##__VA_ARGS__)
namespace phi {
using ScopedTensorDescriptor = paddle::platform::ScopedTensorDescriptor;
@@ -85,6 +112,20 @@ static inline int Log2Ceil(int value) {
return log2_value;
}
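// Heuristic block-size pick for one softmax row: grow a power of two up to
// roughly min(dim_size / vec_size, 1024) threads (halved when vector loads
// are in play), and never go below one full warp.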
inline int getBlockSize(int vec_size, uint64_t dim_size) {
uint64_t block_size = 1;
uint64_t max_block_size =
std::min(dim_size / vec_size, static_cast<uint64_t>(1024));
if (vec_size > 1) {
max_block_size /= 2;
}
while (block_size < (max_block_size)) block_size *= 2;
block_size = std::max(block_size, static_cast<uint64_t>(32));
return block_size;
}
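A standalone host-side sketch of the heuristic with a few worked values may help; GetBlockSizeSketch is a hypothetical duplicate for illustration only, not part of the diff:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Standalone copy of the heuristic above, for illustration only.
static int GetBlockSizeSketch(int vec_size, uint64_t dim_size) {
  uint64_t block_size = 1;
  uint64_t max_block_size =
      std::min(dim_size / vec_size, static_cast<uint64_t>(1024));
  if (vec_size > 1) max_block_size /= 2;  // leave headroom for vector loads
  while (block_size < max_block_size) block_size *= 2;
  return static_cast<int>(std::max(block_size, static_cast<uint64_t>(32)));
}

int main() {
  assert(GetBlockSizeSketch(4, 100000) == 512);  // float with 16-byte vectors
  assert(GetBlockSizeSketch(8, 100000) == 512);  // half with 16-byte vectors
  assert(GetBlockSizeSketch(4, 200) == 32);      // clamped to one full warp
  return 0;
}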
template <typename T, int BatchSize, int WarpSize>
__device__ __forceinline__ void WarpReduceSum(T* sum) {
#pragma unroll
@@ -111,6 +152,41 @@ __device__ __forceinline__ void WarpReduceMax(T* sum) {
}
}
template <typename T>
__inline__ __device__ void BlockReduceMax(T* val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
WarpReduceMax<T, 1, 32>(val);
if (lane == 0) shared[wid] = *val;
__syncthreads();
int block_span = (blockDim.x + warpSize - 1) >> 5;
  *val = (lane < block_span) ? shared[lane]
                             : -std::numeric_limits<T>::infinity();
WarpReduceMax<T, 1, 32>(val);
}
template <typename T>
__inline__ __device__ void BlockReduceSum(T* val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
WarpReduceSum<T, 1, 32>(val);
__syncthreads();
if (lane == 0) shared[wid] = *val;
__syncthreads();
int block_span = (blockDim.x + warpSize - 1) >> 5;
*val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
WarpReduceSum<T, 1, 32>(val);
}
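Both helpers reduce in two stages: a warp-level shuffle reduction, then one partial per warp staged through the static shared[32] buffer and re-reduced by every warp, so all threads end up holding the block-wide result; this assumes blockDim.x <= 1024 (at most 32 warps). A minimal usage sketch follows; RowMaxSketch is illustrative and not part of the diff:

// Illustrative usage only: one row per block, every thread finishes with the
// row maximum. Assumes blockDim.x <= 1024 and T is float or double.
template <typename T>
__global__ void RowMaxSketch(const T* in, T* out, int row_len) {
  const T* row = in + blockIdx.x * row_len;
  T val = -std::numeric_limits<T>::infinity();
  for (int i = threadIdx.x; i < row_len; i += blockDim.x) {
    val = max(val, row[i]);
  }
  BlockReduceMax<T>(&val);  // block-wide max, valid in all threads after this
  if (threadIdx.x == 0) out[blockIdx.x] = val;
}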
template <typename Tx, typename Ty = Tx>
struct ReduceMaxFunctor {
inline Ty initial() { return -std::numeric_limits<Ty>::infinity(); }
@@ -120,6 +196,14 @@ struct ReduceMaxFunctor {
}
};
template <typename T, typename AccT>
struct MaxFunctor {
__device__ __forceinline__ AccT operator()(const AccT& max_v,
const T& v) const {
return max(max_v, static_cast<AccT>(v));
}
};
template <typename Tx, typename Ty = Tx>
struct ExpFunctor {
HOSTDEVICE inline Ty operator()(const Tx& x) const {
@@ -245,6 +329,126 @@ struct LogSoftmaxBackwardFunctor {
Tx sum;
};
template <typename T, typename AccT>
struct SumExpFunctor {
HOSTDEVICE inline SumExpFunctor(AccT v) : max_v(v) {}
HOSTDEVICE inline AccT operator()(AccT sum, T v) const {
return sum + std::exp(static_cast<AccT>(v) - max_v);
}
private:
AccT max_v;
};
template <template <typename, typename> class Reduction,
typename T,
typename AccT,
int VecSize>
__device__ __forceinline__ AccT
ThreadVecReduce(const T* data,
int dim_size,
const Reduction<T, AccT>& functor,
AccT default_value) {
using VecT = phi::AlignedVector<T, VecSize>;
AccT thread_val = default_value;
const int last = dim_size % (VecSize * blockDim.x);
T v[VecSize];
VecT* value = reinterpret_cast<VecT*>(&v);
for (int offset = threadIdx.x; offset * VecSize < dim_size - last;
offset += blockDim.x) {
*value = reinterpret_cast<VecT*>(const_cast<T*>(data))[offset];
#pragma unroll
for (int i = 0; i < VecSize; i++) {
thread_val = functor(thread_val, v[i]);
}
}
for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
offset += blockDim.x) {
thread_val = functor(thread_val, data[offset]);
}
return thread_val;
}
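A worked partition example (hypothetical sizes, not from the diff):

// Worked example: blockDim.x = 128, VecSize = 4, dim_size = 1000.
//   last = 1000 % (4 * 128) = 488
//   vectorized loop: offsets 0..127 each load one 16-byte vector, covering
//                    elements [0, 512) in a single trip
//   scalar tail:     elements [512, 1000) handled one at a time
// ThreadVecWrite below partitions the row the same way for stores.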
template <template <typename, typename> class Reduction,
typename T,
typename AccT,
int VecSize>
__device__ __forceinline__ void ThreadVecWrite(T* out,
const T* input,
int dim_size,
Reduction<AccT, T> functor) {
using VecT = phi::AlignedVector<T, VecSize>;
const int last = dim_size % (VecSize * blockDim.x);
T in_v[VecSize];
VecT* in_value = reinterpret_cast<VecT*>(&in_v);
T out_v[VecSize];
VecT* out_value = reinterpret_cast<VecT*>(&out_v);
for (int offset = threadIdx.x; offset * VecSize < dim_size - last;
offset += blockDim.x) {
*in_value = reinterpret_cast<VecT*>(const_cast<T*>(input))[offset];
#pragma unroll
for (int i = 0; i < VecSize; i++) {
out_v[i] = functor(static_cast<AccT>(in_v[i]));
}
reinterpret_cast<VecT*>(out)[offset] = *out_value;
}
for (int offset = dim_size - last + threadIdx.x; offset < dim_size;
offset += blockDim.x) {
out[offset] = functor(static_cast<AccT>(input[offset]));
}
}
template <typename T,
typename AccT,
typename IndexType,
int BatchSize,
int VecSize,
bool LogMode = false>
__global__ void KeMatrixSoftmaxForward(T* softmax, const T* src, int dim_size) {
using VecT = phi::AlignedVector<T, VecSize>;
int bid = blockIdx.x;
const T* batch_input = src + bid * dim_size;
T* batch_output = softmax + bid * dim_size;
// get max value
AccT thread_max = ThreadVecReduce<MaxFunctor, T, AccT, VecSize>(
batch_input,
dim_size,
MaxFunctor<T, AccT>(),
      -std::numeric_limits<AccT>::infinity());  // identity for max; ::min()
                                                // is wrong for all-negative rows
BlockReduceMax<AccT>(&thread_max);
// get exp value and sum all
AccT thread_exp = ThreadVecReduce<SumExpFunctor, T, AccT, VecSize>(
batch_input,
dim_size,
SumExpFunctor<T, AccT>(thread_max),
static_cast<AccT>(0.));
BlockReduceSum<AccT>(&thread_exp);
// write data to softmax_output according to the LogMode
if (LogMode) {
LogSoftmaxForwardFunctor<AccT, T> reduction(thread_max,
std::log(thread_exp));
ThreadVecWrite<LogSoftmaxForwardFunctor, T, AccT, VecSize>(
batch_output, batch_input, dim_size, reduction);
} else {
SoftmaxForwardFunctor<AccT, T> reduction(thread_max, thread_exp);
ThreadVecWrite<SoftmaxForwardFunctor, T, AccT, VecSize>(
batch_output, batch_input, dim_size, reduction);
}
}
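In formulas, the kernel evaluates the numerically stable softmax (standard identities, stated here for reference):

$$m = \max_j x_j, \qquad Z = \sum_j e^{x_j - m}, \qquad \mathrm{softmax}(x_i) = \frac{e^{x_i - m}}{Z}, \qquad \log\mathrm{softmax}(x_i) = (x_i - m) - \log Z$$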
/*
Core function of computing softmax forward for axis=-1.
The computation includes
@@ -927,6 +1131,30 @@ void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
}
}
template <typename T, typename IndexType, bool LogMode>
void LaunchKeMatrixSoftmaxForwardKernel(
const GPUContext& dev_ctx, T* out, const T* input, int N, int dim_size) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
const int vec_size = MATRIX_SOFTMAX_ALIGN_BYTES / sizeof(T);
switch (getBlockSize(vec_size, dim_size)) {
FIXED_BLOCK_DIM(switch (vec_size) {
FIXED_VEC_SIZE(
KeMatrixSoftmaxForward<T,
AccT,
IndexType,
kBlockDim,
VecSize,
LogMode>
<<<N, kBlockDim, 0, dev_ctx.stream()>>>(out, input, dim_size));
default:
break;
});
default:
PADDLE_THROW(
          errors::Fatal("Invalid input dim_size in the softmax CUDA kernel."));
}
}
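Here N is the number of rows, launched one block per row. With MATRIX_SOFTMAX_ALIGN_BYTES = 16, vec_size is 4 for float and 8 for half/bfloat16, matching the two FIXED_VEC_SIZE cases; other element widths fall through the inner default without launching. For one (block, vector) pair the macro-generated switch expands to roughly the following (illustrative expansion, not part of the diff):

// Illustrative expansion for kBlockDim = 512, VecSize = 8:
case 512: {
  constexpr auto kBlockDim = 512;
  switch (vec_size) {
    case 8: {
      constexpr auto VecSize = 8;
      KeMatrixSoftmaxForward<T, AccT, IndexType, kBlockDim, VecSize, LogMode>
          <<<N, kBlockDim, 0, dev_ctx.stream()>>>(out, input, dim_size);
    } break;
    default:
      break;
  }
} break;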
#if CUDNN_VERSION < 8100
template <>
inline void LaunchSoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(
@@ -967,7 +1195,8 @@ bool UseCudnnSoftmax(const GPUContext& ctx,
}
constexpr int max_dim = 512;
if (!cudnn_available || !last_dim ||
(softmax_dim <= max_dim && sizeof(T) <= 4)) {
(softmax_dim <= max_dim && sizeof(T) <= 4) ||
softmax_dim >= MATRIX_SOFTMAX_THREAHOLD) {
return false;
} else {
return true;
@@ -991,6 +1220,11 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
if (dim >= MATRIX_SOFTMAX_THREAHOLD) {
LaunchKeMatrixSoftmaxForwardKernel<T, IndexType, LogMode>(
dev_ctx, out_data, x.data<T>(), N, dim);
return;
}
int dim_log2 = static_cast<int>(Log2Ceil(dim));
IndexType dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
......