diff --git a/paddle/phi/kernels/funcs/axis_utils.h b/paddle/phi/kernels/funcs/axis_utils.h index 02a89471889a7abdda0e9856bf8c8d006895910d..368c4a9e14061cc8628e292f20d3eeea216e30c0 100644 --- a/paddle/phi/kernels/funcs/axis_utils.h +++ b/paddle/phi/kernels/funcs/axis_utils.h @@ -26,24 +26,27 @@ static inline int CanonicalAxis(const int axis, const int rank) { return axis; } -static inline int SizeToAxis(const int axis, DDim dims) { - int size = 1; +template +static inline T SizeToAxis(const int axis, DDim dims) { + T size = 1; for (int i = 0; i < axis; i++) { size *= dims[i]; } return size; } +template static inline int SizeFromAxis(const int axis, DDim dims) { - int size = 1; + T size = 1; for (int i = axis; i < dims.size(); i++) { size *= dims[i]; } return size; } +template static inline int SizeOutAxis(const int axis, DDim dims) { - int size = 1; + T size = 1; for (int i = axis + 1; i < dims.size(); i++) { size *= dims[i]; } diff --git a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h index ffc6a2e3d6f3276a20f4877c1438050e9863f527..a15244cc5260f525c8b7e2d508c31c4a099de81e 100644 --- a/paddle/phi/kernels/gpudnn/softmax_gpudnn.h +++ b/paddle/phi/kernels/gpudnn/softmax_gpudnn.h @@ -258,30 +258,33 @@ api to compute max (sum) in one warp. template __global__ void WarpSoftmaxForward(T* softmax, const T* src, - const int batch_size, - const int stride, - const int element_count) { - constexpr int kDimCeil = 1 << Log2Elements; - constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; - constexpr int kVSize = sizeof(VecT) / sizeof(T); - constexpr int kLoops = kDimCeil / kWarpSize; - constexpr int kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1; - constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1; - int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize; - constexpr int kStep = kBatchSize * kLoopsV * kVSize; - constexpr int kVItem = kLoopsV * kVSize; + const IndexType batch_size, + const IndexType stride, + const IndexType element_count) { + constexpr IndexType kDimCeil = 1 << Log2Elements; + constexpr IndexType kWarpSize = (kDimCeil < 32) ? kDimCeil : 32; + constexpr IndexType kVSize = sizeof(VecT) / sizeof(T); + constexpr IndexType kLoops = kDimCeil / kWarpSize; + constexpr IndexType kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1; + constexpr IndexType kBatchSize = (kDimCeil <= 32) ? 2 : 1; + IndexType first_batch = + (static_cast(blockDim.y) * blockIdx.x + threadIdx.y) * + kBatchSize; + constexpr IndexType kStep = kBatchSize * kLoopsV * kVSize; + constexpr IndexType kVItem = kLoopsV * kVSize; constexpr AccT kLowInf = -std::numeric_limits::infinity(); using kMode = kps::details::ReduceMode; // max index to read - int idx_max_v[kBatchSize]; + IndexType idx_max_v[kBatchSize]; #pragma unroll - for (int i = 0; i < kBatchSize; i++) { - int idx_max = ((i + first_batch) < batch_size) ? element_count : 0; + for (IndexType i = 0; i < kBatchSize; i++) { + IndexType idx_max = ((i + first_batch) < batch_size) ? element_count : 0; idx_max_v[i] = idx_max / kVSize; } @@ -307,7 +310,7 @@ __global__ void WarpSoftmaxForward(T* softmax, // read data from global memory #pragma unroll - for (int i = 0; i < kBatchSize; ++i) { + for (IndexType i = 0; i < kBatchSize; ++i) { const VecT* src_v = reinterpret_cast(&src[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&src_data[i][0][0]); @@ -328,7 +331,7 @@ __global__ void WarpSoftmaxForward(T* softmax, // compute sum #pragma unroll - for (int i = 0; i < kBatchSize; ++i) { + for (IndexType i = 0; i < kBatchSize; ++i) { kps::ElementwiseUnary>( &sub_data[i][0][0], &sub_data[i][0][0], UnarySubFunctor(max[i])); kps::ElementwiseUnary>( @@ -344,7 +347,7 @@ __global__ void WarpSoftmaxForward(T* softmax, // write data to global memory #pragma unroll - for (int i = 0; i < kBatchSize; ++i) { + for (IndexType i = 0; i < kBatchSize; ++i) { VecT* softmax_v = reinterpret_cast(&softmax[(first_batch + i) * stride]); VecT* reg_v = reinterpret_cast(&out_tmp[i][0][0]); @@ -489,26 +492,26 @@ __global__ void WarpSoftmaxBackward(T* dst, } } -#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \ - case Log2Elements: \ - WarpSoftmaxForward \ - <<>>( \ - dst, src, batch_size, stride, element_count); \ +#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT) \ + case Log2Elements: \ + WarpSoftmaxForward \ + <<>>( \ + dst, src, batch_size, stride, element_count); \ break; /* Wrapper of softmax formward with template instantiation on size of input. */ -template -void SwitchWarpSoftmaxForward(const int blocks, +template +void SwitchWarpSoftmaxForward(const IndexType blocks, const dim3 threads, const GPUContext& dev_ctx, T* dst, const T* src, - const int batch_size, - const int stride, - const int element_count, - int Log2Elements) { + const IndexType batch_size, + const IndexType stride, + const IndexType element_count, + IndexType Log2Elements) { using AccT = typename phi::dtype::MPTypeTrait::Type; switch (Log2Elements) { SOFTMAX_WARP_FORWARD_CASE(0, AccT); @@ -758,11 +761,12 @@ void LaunchNormalSoftmaxBackward(const GPUContext& dev_ctx, } } -static std::vector GetSoftmaxTensorDims(const phi::DDim& dims, - const int axis) { - int dim = dims[axis]; - int N = phi::funcs::SizeToAxis(axis, dims); - int D = phi::funcs::SizeOutAxis(axis, dims); +template +static std::vector GetSoftmaxTensorDims(const phi::DDim& dims, + const int axis) { + auto dim = static_cast(dims[axis]); + auto N = phi::funcs::SizeToAxis(axis, dims); + auto D = phi::funcs::SizeOutAxis(axis, dims); return {N, dim, D, 1}; } @@ -950,7 +954,9 @@ inline void LaunchSoftmaxBackwardCudnnKernel( #endif template -bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) { +bool UseCudnnSoftmax(const GPUContext& ctx, + int64_t softmax_dim, + bool last_dim) { bool cudnn_available = ctx.cudnn_handle(); if (!ctx.cudnn_handle()) { if (std::is_same::value) { @@ -968,24 +974,25 @@ bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) { } } -template -void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, - const DenseTensor& x, - const int input_axis, - DenseTensor* out) { +template +void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx, + const DenseTensor& x, + const int input_axis, + DenseTensor* out) { auto* out_data = out->data(); int rank = x.dims().size(); int axis = phi::funcs::CanonicalAxis(input_axis, rank); - std::vector tensor_dims = GetSoftmaxTensorDims(x.dims(), axis); - int N = tensor_dims[0]; - int dim = tensor_dims[1]; + std::vector tensor_dims = + GetSoftmaxTensorDims(x.dims(), axis); + IndexType N = tensor_dims[0]; + IndexType dim = tensor_dims[1]; int D = tensor_dims[2]; if (D == 1) { if (!UseCudnnSoftmax(dev_ctx, dim, true)) { int dim_log2 = static_cast(Log2Ceil(dim)); - int dim_ceil = 1 << dim_log2; + IndexType dim_ceil = 1 << dim_log2; int warp_size = (dim_ceil < 32) ? dim_ceil : 32; int batches_per_warp = (dim_ceil <= 32) ? 2 : 1; @@ -994,7 +1001,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, int warps_per_block = (threads_per_block / warp_size); int batches_per_block = warps_per_block * batches_per_warp; - int blocks = (N + batches_per_block - 1) / batches_per_block; + IndexType blocks = (N + batches_per_block - 1) / batches_per_block; dim3 threads(warp_size, warps_per_block, 1); // vectorization read/write @@ -1002,35 +1009,35 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, using T2 = typename VecT2::Type; if (dim % 4 == 0) { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); } else if (dim % 2 == 0) { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); } else { - SwitchWarpSoftmaxForward(blocks, - threads, - dev_ctx, - out_data, - x.data(), - N, - dim, - dim, - dim_log2); + SwitchWarpSoftmaxForward(blocks, + threads, + dev_ctx, + out_data, + x.data(), + N, + dim, + dim, + dim_log2); } } else { LaunchSoftmaxForwardCudnnKernel(dev_ctx, x, axis, LogMode, out); @@ -1041,6 +1048,20 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, } } +template +void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx, + const DenseTensor& x, + const int input_axis, + DenseTensor* out) { + if (x.numel() >= std::numeric_limits::max()) { + SoftmaxForwardCUDAKernelDriverImpl( + dev_ctx, x, input_axis, out); + } else { + SoftmaxForwardCUDAKernelDriverImpl( + dev_ctx, x, input_axis, out); + } +} + template void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, const DenseTensor& out,