Unverified commit 38edea9a, authored by sneaxiy, committed by GitHub

Fix softmax op when the input shape is larger than INT32_MAX (#45897)

* fix softmax int64

* follow comments
Parent bd8f998b
@@ -26,24 +26,27 @@ static inline int CanonicalAxis(const int axis, const int rank) {
   return axis;
 }
 
-static inline int SizeToAxis(const int axis, DDim dims) {
-  int size = 1;
+template <typename T = int>
+static inline T SizeToAxis(const int axis, DDim dims) {
+  T size = 1;
   for (int i = 0; i < axis; i++) {
     size *= dims[i];
   }
   return size;
 }
 
+template <typename T = int>
 static inline int SizeFromAxis(const int axis, DDim dims) {
-  int size = 1;
+  T size = 1;
   for (int i = axis; i < dims.size(); i++) {
     size *= dims[i];
   }
   return size;
 }
 
+template <typename T = int>
 static inline int SizeOutAxis(const int axis, DDim dims) {
-  int size = 1;
+  T size = 1;
   for (int i = axis + 1; i < dims.size(); i++) {
     size *= dims[i];
   }
...
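The helpers above are templated because the flattened size of a tensor with more than INT32_MAX elements no longer fits in a 32-bit int; accumulating the product in a wider type avoids the overflow. A minimal, self-contained sketch of the same accumulation (plain C++; the Dims alias is a hypothetical stand-in for phi::DDim, and the shape is invented for illustration):

#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Hypothetical stand-in for phi::DDim, for illustration only.
using Dims = std::vector<int64_t>;

// Same accumulation pattern as the templated SizeToAxis above.
template <typename T = int>
T SizeToAxis(int axis, const Dims& dims) {
  T size = 1;
  for (int i = 0; i < axis; ++i) size *= static_cast<T>(dims[i]);
  return size;
}

int main() {
  Dims dims = {65536, 65536, 8};  // 2^32 elements before axis 2
  int64_t n = SizeToAxis<int64_t>(2, dims);
  std::printf("N = %lld, fits in int32: %s\n",
              static_cast<long long>(n),
              n <= std::numeric_limits<int32_t>::max() ? "yes" : "no");
  return 0;
}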
@@ -258,30 +258,33 @@ api to compute max (sum) in one warp.
 template <typename T,
           typename VecT,
           typename AccT,
+          typename IndexType,
           int Log2Elements,
           bool LogMode = false>
 __global__ void WarpSoftmaxForward(T* softmax,
                                    const T* src,
-                                   const int batch_size,
-                                   const int stride,
-                                   const int element_count) {
-  constexpr int kDimCeil = 1 << Log2Elements;
-  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
-  constexpr int kVSize = sizeof(VecT) / sizeof(T);
-  constexpr int kLoops = kDimCeil / kWarpSize;
-  constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
-  constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
-  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
-  constexpr int kStep = kBatchSize * kLoopsV * kVSize;
-  constexpr int kVItem = kLoopsV * kVSize;
+                                   const IndexType batch_size,
+                                   const IndexType stride,
+                                   const IndexType element_count) {
+  constexpr IndexType kDimCeil = 1 << Log2Elements;
+  constexpr IndexType kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
+  constexpr IndexType kVSize = sizeof(VecT) / sizeof(T);
+  constexpr IndexType kLoops = kDimCeil / kWarpSize;
+  constexpr IndexType kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
+  constexpr IndexType kBatchSize = (kDimCeil <= 32) ? 2 : 1;
+  IndexType first_batch =
+      (static_cast<IndexType>(blockDim.y) * blockIdx.x + threadIdx.y) *
+      kBatchSize;
+  constexpr IndexType kStep = kBatchSize * kLoopsV * kVSize;
+  constexpr IndexType kVItem = kLoopsV * kVSize;
   constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
   using kMode = kps::details::ReduceMode;
 
   // max index to read
-  int idx_max_v[kBatchSize];
+  IndexType idx_max_v[kBatchSize];
 #pragma unroll
-  for (int i = 0; i < kBatchSize; i++) {
-    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
+  for (IndexType i = 0; i < kBatchSize; i++) {
+    IndexType idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
     idx_max_v[i] = idx_max / kVSize;
   }
@@ -307,7 +310,7 @@ __global__ void WarpSoftmaxForward(T* softmax,
 
   // read data from global memory
 #pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
+  for (IndexType i = 0; i < kBatchSize; ++i) {
     const VecT* src_v =
         reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
     VecT* reg_v = reinterpret_cast<VecT*>(&src_data[i][0][0]);
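Templating the kernel on IndexType matters at this read: the source pointer is offset by (first_batch + i) * stride, and with 32-bit operands that product wraps once a batch starts beyond element 2^31 - 1. A small host-side sketch of the arithmetic (the values are illustrative only, not taken from the kernel):

#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values only: a row index and per-row stride whose product
  // exceeds INT32_MAX.
  int64_t first_batch = 3000000;
  int64_t stride = 1024;

  int64_t exact = first_batch * stride;             // 3,072,000,000 elements
  int32_t truncated = static_cast<int32_t>(exact);  // wraps when narrowed to 32 bits

  std::printf("64-bit offset: %lld\n", static_cast<long long>(exact));
  std::printf("32-bit offset: %d\n", truncated);    // negative on typical platforms
  return 0;
}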
@@ -328,7 +331,7 @@ __global__ void WarpSoftmaxForward(T* softmax,
 
   // compute sum
 #pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
+  for (IndexType i = 0; i < kBatchSize; ++i) {
     kps::ElementwiseUnary<AccT, AccT, kVItem, 1, UnarySubFunctor<AccT>>(
         &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor<AccT>(max[i]));
     kps::ElementwiseUnary<AccT, AccT, kVItem, 1, ExpFunctor<AccT>>(
@@ -344,7 +347,7 @@ __global__ void WarpSoftmaxForward(T* softmax,
 
   // write data to global memory
 #pragma unroll
-  for (int i = 0; i < kBatchSize; ++i) {
+  for (IndexType i = 0; i < kBatchSize; ++i) {
     VecT* softmax_v =
         reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
     VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
@@ -489,26 +492,26 @@ __global__ void WarpSoftmaxBackward(T* dst,
   }
 }
 
 #define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT)                   \
   case Log2Elements:                                                    \
-    WarpSoftmaxForward<T, VecT, AccT, Log2Elements, LogMode>            \
+    WarpSoftmaxForward<T, VecT, AccT, IndexType, Log2Elements, LogMode> \
         <<<blocks, threads, 0, dev_ctx.stream()>>>(                     \
             dst, src, batch_size, stride, element_count);               \
     break;
 
 /*
 Wrapper of softmax formward with template instantiation on size of input.
 */
-template <typename T, typename VecT, bool LogMode>
-void SwitchWarpSoftmaxForward(const int blocks,
+template <typename T, typename VecT, typename IndexType, bool LogMode>
+void SwitchWarpSoftmaxForward(const IndexType blocks,
                               const dim3 threads,
                               const GPUContext& dev_ctx,
                               T* dst,
                               const T* src,
-                              const int batch_size,
-                              const int stride,
-                              const int element_count,
-                              int Log2Elements) {
+                              const IndexType batch_size,
+                              const IndexType stride,
+                              const IndexType element_count,
+                              IndexType Log2Elements) {
   using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
   switch (Log2Elements) {
     SOFTMAX_WARP_FORWARD_CASE(0, AccT);
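SwitchWarpSoftmaxForward turns the runtime value log2(dim) into a compile-time template argument, so loop counts such as kDimCeil stay constexpr inside the kernel. A stripped-down sketch of that dispatch pattern (RunKernel, RUN_KERNEL_CASE, and Dispatch are invented for illustration and are not the Paddle macro):

#include <cstdio>

// Each instantiation derives compile-time loop bounds from Log2Elements.
template <int Log2Elements>
void RunKernel(long long element_count) {
  constexpr int kDimCeil = 1 << Log2Elements;
  std::printf("kDimCeil = %d for element_count = %lld\n", kDimCeil, element_count);
}

// Hypothetical macro mirroring the one-case-per-Log2Elements structure.
#define RUN_KERNEL_CASE(L)       \
  case L:                        \
    RunKernel<L>(element_count); \
    break;

void Dispatch(int log2_dim, long long element_count) {
  switch (log2_dim) {
    RUN_KERNEL_CASE(0)
    RUN_KERNEL_CASE(1)
    RUN_KERNEL_CASE(2)
    RUN_KERNEL_CASE(3)
    default:
      break;
  }
}

int main() {
  Dispatch(3, 8);  // selects RunKernel<3>, so kDimCeil == 8
  return 0;
}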
@@ -758,11 +761,12 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
   }
 }
 
-static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
-                                             const int axis) {
-  int dim = dims[axis];
-  int N = phi::funcs::SizeToAxis(axis, dims);
-  int D = phi::funcs::SizeOutAxis(axis, dims);
+template <typename T = int>
+static std::vector<T> GetSoftmaxTensorDims(const phi::DDim& dims,
+                                           const int axis) {
+  auto dim = static_cast<T>(dims[axis]);
+  auto N = phi::funcs::SizeToAxis<T>(axis, dims);
+  auto D = phi::funcs::SizeOutAxis<T>(axis, dims);
   return {N, dim, D, 1};
 }
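GetSoftmaxTensorDims views the input as an [N, dim, D] tensor with softmax taken along dim: N is the product of the dimensions before the axis and D the product of the dimensions after it. A worked example with an assumed shape (not taken from the PR):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Assumed shape [8, 70000, 70000], softmax along the last axis.
  std::vector<int64_t> shape = {8, 70000, 70000};
  int axis = 2;

  int64_t N = 1, D = 1;
  for (int i = 0; i < axis; ++i) N *= shape[i];
  for (int i = axis + 1; i < static_cast<int>(shape.size()); ++i) D *= shape[i];
  int64_t dim = shape[axis];

  // N = 560000, dim = 70000, D = 1; N * dim is about 3.92e10 > INT32_MAX,
  // which is exactly the case the int64 path exists for.
  std::printf("N = %lld, dim = %lld, D = %lld\n",
              static_cast<long long>(N),
              static_cast<long long>(dim),
              static_cast<long long>(D));
  return 0;
}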
@@ -950,7 +954,9 @@ inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
 #endif
 
 template <typename T>
-bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) {
+bool UseCudnnSoftmax(const GPUContext& ctx,
+                     int64_t softmax_dim,
+                     bool last_dim) {
   bool cudnn_available = ctx.cudnn_handle();
   if (!ctx.cudnn_handle()) {
     if (std::is_same<T, phi::dtype::bfloat16>::value) {
@@ -968,24 +974,25 @@ bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) {
   }
 }
 
-template <typename T, bool LogMode = false>
-void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
-                                    const DenseTensor& x,
-                                    const int input_axis,
-                                    DenseTensor* out) {
+template <typename T, typename IndexType, bool LogMode = false>
+void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
+                                        const DenseTensor& x,
+                                        const int input_axis,
+                                        DenseTensor* out) {
   auto* out_data = out->data<T>();
 
   int rank = x.dims().size();
   int axis = phi::funcs::CanonicalAxis(input_axis, rank);
-  std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
-  int N = tensor_dims[0];
-  int dim = tensor_dims[1];
+  std::vector<IndexType> tensor_dims =
+      GetSoftmaxTensorDims<IndexType>(x.dims(), axis);
+  IndexType N = tensor_dims[0];
+  IndexType dim = tensor_dims[1];
   int D = tensor_dims[2];
 
   if (D == 1) {
     if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
       int dim_log2 = static_cast<int>(Log2Ceil(dim));
-      int dim_ceil = 1 << dim_log2;
+      IndexType dim_ceil = 1 << dim_log2;
       int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
       int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;
@@ -994,7 +1001,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
       int warps_per_block = (threads_per_block / warp_size);
       int batches_per_block = warps_per_block * batches_per_warp;
-      int blocks = (N + batches_per_block - 1) / batches_per_block;
+      IndexType blocks = (N + batches_per_block - 1) / batches_per_block;
       dim3 threads(warp_size, warps_per_block, 1);
 
       // vectorization read/write
@@ -1002,35 +1009,35 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
       using T2 = typename VecT2<T>::Type;
 
       if (dim % 4 == 0) {
-        SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
+        SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
                                                  threads,
                                                  dev_ctx,
                                                  out_data,
                                                  x.data<T>(),
                                                  N,
                                                  dim,
                                                  dim,
                                                  dim_log2);
       } else if (dim % 2 == 0) {
-        SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
+        SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
                                                  threads,
                                                  dev_ctx,
                                                  out_data,
                                                  x.data<T>(),
                                                  N,
                                                  dim,
                                                  dim,
                                                  dim_log2);
       } else {
-        SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
+        SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
                                                 threads,
                                                 dev_ctx,
                                                 out_data,
                                                 x.data<T>(),
                                                 N,
                                                 dim,
                                                 dim,
                                                 dim_log2);
       }
     } else {
       LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
@@ -1041,6 +1048,20 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
   }
 }
 
+template <typename T, bool LogMode = false>
+void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
+                                    const DenseTensor& x,
+                                    const int input_axis,
+                                    DenseTensor* out) {
+  if (x.numel() >= std::numeric_limits<int32_t>::max()) {
+    SoftmaxForwardCUDAKernelDriverImpl<T, int64_t, LogMode>(
+        dev_ctx, x, input_axis, out);
+  } else {
+    SoftmaxForwardCUDAKernelDriverImpl<T, int32_t, LogMode>(
+        dev_ctx, x, input_axis, out);
+  }
+}
+
 template <typename T, bool LogMode = false>
 void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                      const DenseTensor& out,
...
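The new SoftmaxForwardCUDAKernelDriver is the dispatch point of the fix: the cheaper 32-bit index math is kept for the common case, and the 64-bit instantiation is selected only when the element count may not fit in int32. The same pattern in isolation, as a minimal sketch with invented names:

#include <cstdint>
#include <cstdio>
#include <limits>

// Hypothetical kernel-launch wrapper templated on the index type.
template <typename IndexType>
void RunSoftmaxImpl(int64_t numel) {
  std::printf("launch with %zu-byte indices for %lld elements\n",
              sizeof(IndexType), static_cast<long long>(numel));
}

// Pick the index type at runtime from the element count.
void RunSoftmax(int64_t numel) {
  if (numel >= std::numeric_limits<int32_t>::max()) {
    RunSoftmaxImpl<int64_t>(numel);
  } else {
    RunSoftmaxImpl<int32_t>(numel);
  }
}

int main() {
  RunSoftmax(1024);          // 32-bit index path
  RunSoftmax(5000000000LL);  // 64-bit index path
  return 0;
}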