Unverified commit 38edea9a, authored by sneaxiy, committed by GitHub

Fix softmax op when the input shape is larger than INT32_MAX (#45897)

* fix softmax int64

* follow comments
Parent commit: bd8f998b
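For context, the overflow the title refers to: the element count of a large tensor no longer fits in the 32-bit int indices the softmax CUDA kernels used. A minimal standalone illustration (not part of the patch; the shape is hypothetical):

#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Hypothetical shape: 65536 x 65536, i.e. 2^32 elements.
  const int64_t dims[] = {65536, 65536};
  int64_t numel = 1;
  for (int64_t d : dims) numel *= d;  // exact in 64 bits
  // 4294967296 exceeds INT32_MAX (2147483647), so any 32-bit index or size
  // computed for this tensor would overflow.
  std::cout << numel << " > INT32_MAX: "
            << (numel > std::numeric_limits<int32_t>::max()) << "\n";
}

The patch below therefore threads an index-type template parameter (IndexType) through the size helpers and the CUDA kernels, and dispatches to int64_t only when x.numel() reaches INT32_MAX.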
@@ -26,24 +26,27 @@ static inline int CanonicalAxis(const int axis, const int rank) {
return axis;
}
static inline int SizeToAxis(const int axis, DDim dims) {
int size = 1;
template <typename T = int>
static inline T SizeToAxis(const int axis, DDim dims) {
T size = 1;
for (int i = 0; i < axis; i++) {
size *= dims[i];
}
return size;
}
template <typename T = int>
static inline int SizeFromAxis(const int axis, DDim dims) {
int size = 1;
T size = 1;
for (int i = axis; i < dims.size(); i++) {
size *= dims[i];
}
return size;
}
template <typename T = int>
static inline int SizeOutAxis(const int axis, DDim dims) {
int size = 1;
T size = 1;
for (int i = axis + 1; i < dims.size(); i++) {
size *= dims[i];
}
......
@@ -258,30 +258,33 @@ api to compute max (sum) in one warp.
template <typename T,
typename VecT,
typename AccT,
typename IndexType,
int Log2Elements,
bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax,
const T* src,
const int batch_size,
const int stride,
const int element_count) {
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kLoops = kDimCeil / kWarpSize;
constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
constexpr int kStep = kBatchSize * kLoopsV * kVSize;
constexpr int kVItem = kLoopsV * kVSize;
const IndexType batch_size,
const IndexType stride,
const IndexType element_count) {
constexpr IndexType kDimCeil = 1 << Log2Elements;
constexpr IndexType kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr IndexType kVSize = sizeof(VecT) / sizeof(T);
constexpr IndexType kLoops = kDimCeil / kWarpSize;
constexpr IndexType kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
constexpr IndexType kBatchSize = (kDimCeil <= 32) ? 2 : 1;
IndexType first_batch =
(static_cast<IndexType>(blockDim.y) * blockIdx.x + threadIdx.y) *
kBatchSize;
constexpr IndexType kStep = kBatchSize * kLoopsV * kVSize;
constexpr IndexType kVItem = kLoopsV * kVSize;
constexpr AccT kLowInf = -std::numeric_limits<AccT>::infinity();
using kMode = kps::details::ReduceMode;
// max index to read
int idx_max_v[kBatchSize];
IndexType idx_max_v[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
for (IndexType i = 0; i < kBatchSize; i++) {
IndexType idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
idx_max_v[i] = idx_max / kVSize;
}
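A note on the first_batch computation above: blockDim.y and blockIdx.x are 32-bit unsigned values inside a CUDA kernel, so without the static_cast their product would be computed modulo 2^32 before being widened to IndexType. A small host-side sketch of the same promotion rule (the launch values are hypothetical):

#include <cstdint>
#include <iostream>

int main() {
  uint32_t block_dim_y = 4, block_idx_x = 1500000000;  // hypothetical launch config
  uint64_t wrapped  = block_dim_y * block_idx_x;                         // 32-bit multiply, then widened: wraps
  uint64_t promoted = static_cast<uint64_t>(block_dim_y) * block_idx_x;  // 64-bit multiply: exact
  std::cout << wrapped << " vs " << promoted << "\n";  // 1705032704 vs 6000000000
}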
@@ -307,7 +310,7 @@ __global__ void WarpSoftmaxForward(T* softmax,
// read data from global memory
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
for (IndexType i = 0; i < kBatchSize; ++i) {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&src_data[i][0][0]);
@@ -328,7 +331,7 @@ __global__ void WarpSoftmaxForward(T* softmax,
// compute sum
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
for (IndexType i = 0; i < kBatchSize; ++i) {
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, UnarySubFunctor<AccT>>(
&sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor<AccT>(max[i]));
kps::ElementwiseUnary<AccT, AccT, kVItem, 1, ExpFunctor<AccT>>(
@@ -344,7 +347,7 @@ __global__ void WarpSoftmaxForward(T* softmax,
// write data to global memory
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
for (IndexType i = 0; i < kBatchSize; ++i) {
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
VecT* reg_v = reinterpret_cast<VecT*>(&out_tmp[i][0][0]);
@@ -489,26 +492,26 @@ __global__ void WarpSoftmaxBackward(T* dst,
}
}
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \
case Log2Elements: \
WarpSoftmaxForward<T, VecT, AccT, Log2Elements, LogMode> \
<<<blocks, threads, 0, dev_ctx.stream()>>>( \
dst, src, batch_size, stride, element_count); \
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \
case Log2Elements: \
WarpSoftmaxForward<T, VecT, AccT, IndexType, Log2Elements, LogMode> \
<<<blocks, threads, 0, dev_ctx.stream()>>>( \
dst, src, batch_size, stride, element_count); \
break;
/*
Wrapper of softmax forward with template instantiation on size of input.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxForward(const int blocks,
template <typename T, typename VecT, typename IndexType, bool LogMode>
void SwitchWarpSoftmaxForward(const IndexType blocks,
const dim3 threads,
const GPUContext& dev_ctx,
T* dst,
const T* src,
const int batch_size,
const int stride,
const int element_count,
int Log2Elements) {
const IndexType batch_size,
const IndexType stride,
const IndexType element_count,
IndexType Log2Elements) {
using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
switch (Log2Elements) {
SOFTMAX_WARP_FORWARD_CASE(0, AccT);
@@ -758,11 +761,12 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx,
}
}
static std::vector<int> GetSoftmaxTensorDims(const phi::DDim& dims,
const int axis) {
int dim = dims[axis];
int N = phi::funcs::SizeToAxis(axis, dims);
int D = phi::funcs::SizeOutAxis(axis, dims);
template <typename T = int>
static std::vector<T> GetSoftmaxTensorDims(const phi::DDim& dims,
const int axis) {
auto dim = static_cast<T>(dims[axis]);
auto N = phi::funcs::SizeToAxis<T>(axis, dims);
auto D = phi::funcs::SizeOutAxis<T>(axis, dims);
return {N, dim, D, 1};
}
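For reference, a standalone restatement of the helper above with a worked (hypothetical) shape: softmax over axis 1 of an [8, 3000000000, 4] tensor packs into {8, 3000000000, 4, 1}, where the middle entry already exceeds INT32_MAX.

#include <cstdint>
#include <vector>

// Mirrors GetSoftmaxTensorDims<int64_t> as a plain function: flatten the dims
// before axis into N and the dims after it into D, keeping the softmax dim.
static std::vector<int64_t> SoftmaxDims(const std::vector<int64_t>& dims, int axis) {
  int64_t N = 1, D = 1;
  for (int i = 0; i < axis; ++i) N *= dims[i];
  for (size_t i = axis + 1; i < dims.size(); ++i) D *= dims[i];
  return {N, dims[axis], D, 1};
}
// SoftmaxDims({8, 3000000000, 4}, 1) == {8, 3000000000, 4, 1}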
@@ -950,7 +954,9 @@ inline void LaunchSoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
#endif
template <typename T>
bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) {
bool UseCudnnSoftmax(const GPUContext& ctx,
int64_t softmax_dim,
bool last_dim) {
bool cudnn_available = ctx.cudnn_handle();
if (!ctx.cudnn_handle()) {
if (std::is_same<T, phi::dtype::bfloat16>::value) {
@@ -968,24 +974,25 @@ bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) {
}
}
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
const DenseTensor& x,
const int input_axis,
DenseTensor* out) {
template <typename T, typename IndexType, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
const int input_axis,
DenseTensor* out) {
auto* out_data = out->data<T>();
int rank = x.dims().size();
int axis = phi::funcs::CanonicalAxis(input_axis, rank);
std::vector<int> tensor_dims = GetSoftmaxTensorDims(x.dims(), axis);
int N = tensor_dims[0];
int dim = tensor_dims[1];
std::vector<IndexType> tensor_dims =
GetSoftmaxTensorDims<IndexType>(x.dims(), axis);
IndexType N = tensor_dims[0];
IndexType dim = tensor_dims[1];
int D = tensor_dims[2];
if (D == 1) {
if (!UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
int dim_log2 = static_cast<int>(Log2Ceil(dim));
int dim_ceil = 1 << dim_log2;
IndexType dim_ceil = 1 << dim_log2;
int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
int batches_per_warp = (dim_ceil <= 32) ? 2 : 1;
@@ -994,7 +1001,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
int warps_per_block = (threads_per_block / warp_size);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
IndexType blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(warp_size, warps_per_block, 1);
// vectorization read/write
@@ -1002,35 +1009,35 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
using T2 = typename VecT2<T>::Type;
if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxForward<T, T, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
}
} else {
LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
@@ -1041,6 +1048,20 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
}
}
template <typename T, bool LogMode = false>
void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
const DenseTensor& x,
const int input_axis,
DenseTensor* out) {
if (x.numel() >= std::numeric_limits<int32_t>::max()) {
SoftmaxForwardCUDAKernelDriverImpl<T, int64_t, LogMode>(
dev_ctx, x, input_axis, out);
} else {
SoftmaxForwardCUDAKernelDriverImpl<T, int32_t, LogMode>(
dev_ctx, x, input_axis, out);
}
}
template <typename T, bool LogMode = false>
void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
const DenseTensor& out,
......
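The wrapper added at the end of the forward driver is where the index width is chosen: the int64_t instantiation runs only when x.numel() reaches INT32_MAX, so the common case keeps 32-bit index arithmetic, which generally needs fewer registers and instructions on the GPU. A condensed restatement of that dispatch idiom (illustrative only; the KernelImpl::Run interface is hypothetical):

#include <cstdint>
#include <limits>

// Pick the index width at runtime; the templated implementation is
// instantiated once per width, as the driver above does.
template <template <typename> class KernelImpl>
void DispatchByNumel(int64_t numel) {
  if (numel >= std::numeric_limits<int32_t>::max()) {
    KernelImpl<int64_t>::Run();  // rare: tensors too large for 32-bit indexing
  } else {
    KernelImpl<int32_t>::Run();  // common case: cheaper 32-bit index arithmetic
  }
}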