diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu b/paddle/fluid/operators/softmax_cudnn_op.cu
index ece1d57743a0571add5d53a34ced77250b76d3f9..26d4f7a5e97fb2106dd9ae01d0343d763156e017 100644
--- a/paddle/fluid/operators/softmax_cudnn_op.cu
+++ b/paddle/fluid/operators/softmax_cudnn_op.cu
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/math/math_cuda_utils.h"
 #include "paddle/fluid/operators/softmax_op.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
 #include "paddle/fluid/platform/cudnn_helper.h"
 
 namespace paddle {
@@ -31,6 +32,13 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using DataLayout = platform::DataLayout;
 using Tensor = framework::Tensor;
 
+#define LAUNCH_SOFTMAX_WARP_FORWARD(Log2Elements)                  \
+  case Log2Elements:                                               \
+    WarpSoftmaxForward<T, float, Log2Elements><<<                  \
+        blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \
+        out_data, x->data<T>(), N, dim, dim);                      \
+    break;
+
 static inline int SizeOutAxis(const int axis, DDim dims) {
   int size = 1;
   for (int i = axis + 1; i < dims.size(); i++) {
@@ -39,6 +47,12 @@ static inline int SizeOutAxis(const int axis, DDim dims) {
   return size;
 }
 
+int log2_ceil(int value) {
+  int log2_value = 0;
+  while ((1 << log2_value) < value) ++log2_value;
+  return log2_value;
+}
+
 template <typename T, int VecSize>
 union vec_t {
   static_assert(sizeof(T) == -1, "vec_t is only available by specialization.");
@@ -84,6 +98,107 @@ __global__ void VecSoftmaxForward(T* dst, const T* src, const int batch_size,
   reinterpret_cast<VECT*>(&dst[offset + idx])[0] = buf;
 }
 
+template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX>
+__device__ __forceinline__ void warp_reduce_sum(T* sum) {
+#pragma unroll
+  for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) {
+#pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+      T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
+      sum[i] = sum[i] + sum_val;
+    }
+  }
+}
+
+template <typename T, int WARP_BATCH, int WARP_SIZE_SOFTMAX>
+__device__ __forceinline__ void warp_reduce_max(T* sum) {
+#pragma unroll
+  for (int offset = WARP_SIZE_SOFTMAX / 2; offset > 0; offset /= 2) {
+#pragma unroll
+    for (int i = 0; i < WARP_BATCH; ++i) {
+      T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, sum[i], offset);
+      sum[i] = max(sum[i], max_val);
+    }
+  }
+}
+
+template <typename T, typename AccT, int Log2Elements>
+__global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size,
+                                   const int stride, const int element_count) {
+  constexpr int next_power_of_two = 1 << Log2Elements;
+  constexpr int warp_size_softmax =
+      (next_power_of_two < 32) ? next_power_of_two : 32;
+  constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax;
+  constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
+
+  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
+
+  int local_batches = batch_size - first_batch;
+  if (local_batches > WARP_BATCH) {
+    local_batches = WARP_BATCH;
+  }
+
+  int local_idx = threadIdx.x;
+
+  src += first_batch * stride + local_idx;
+  dst += first_batch * stride + local_idx;
+
+  // load data from global memory
+  AccT elements[WARP_BATCH][WARP_ITERATIONS];
+  for (int i = 0; i < WARP_BATCH; ++i) {
+    int batch_element_count = (i >= local_batches) ? 0 : element_count;
+    for (int it = 0; it < WARP_ITERATIONS; ++it) {
+      int element_index = local_idx + it * warp_size_softmax;
+      if (element_index < batch_element_count) {
+        elements[i][it] =
+            static_cast<AccT>(src[i * element_count + it * warp_size_softmax]);
+      } else {
+        elements[i][it] = -std::numeric_limits<AccT>::infinity();
+      }
+    }
+  }
+
+  // compute max_value
+  AccT max_value[WARP_BATCH];
+#pragma unroll
+  for (int i = 0; i < WARP_BATCH; ++i) {
+    max_value[i] = elements[i][0];
+#pragma unroll
+    for (int it = 1; it < WARP_ITERATIONS; ++it) {
+      max_value[i] =
+          (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
+    }
+  }
+  warp_reduce_max<AccT, WARP_BATCH, warp_size_softmax>(max_value);
+
+  AccT sum[WARP_BATCH]{0.0f};
+#pragma unroll
+  for (int i = 0; i < WARP_BATCH; ++i) {
+#pragma unroll
+    for (int it = 0; it < WARP_ITERATIONS; ++it) {
+      elements[i][it] = (std::exp((elements[i][it] - max_value[i])));
+      sum[i] += elements[i][it];
+    }
+  }
+  warp_reduce_sum<AccT, WARP_BATCH, warp_size_softmax>(sum);
+
+// store result
+#pragma unroll
+  for (int i = 0; i < WARP_BATCH; ++i) {
+    if (i >= local_batches) break;
+#pragma unroll
+    for (int it = 0; it < WARP_ITERATIONS; ++it) {
+      int element_index = local_idx + it * warp_size_softmax;
+      if (element_index < element_count) {
+        dst[i * element_count + it * warp_size_softmax] =
+            elements[i][it] / sum[i];
+      } else {
+        break;
+      }
+    }
+  }
+}
+
 template <typename T, int VPT, int WARP_PER_BLOCK>
 __global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src,
                                    const int batch_size,
@@ -130,26 +245,61 @@ class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
     const int N = SizeToAxis(axis, dims);
     const int D = SizeOutAxis(axis, dims);
 
+    constexpr int max_dim = 320;
+    bool optimize = false;
     constexpr int warps_per_block = 4;
-    if (D == 1 && dim == 128 && N % warps_per_block == 0 && sizeof(T) <= 4) {
-      // a warp for a batch, 4 elements for a thread, only support the softmax
-      // dim size = 128 currently
-      if (sizeof(T) == 2) {
-        VecSoftmaxForward<
-            T, int2, 4,
-            warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
-                               0, ctx.cuda_device_context().stream()>>>(
-            out_data, x->data<T>(), N, dim);
-      } else if (sizeof(T) == 4) {
-        VecSoftmaxForward<
-            T, int4, 4,
-            warps_per_block><<<N / warps_per_block, warps_per_block * WARP_SIZE,
-                               0, ctx.cuda_device_context().stream()>>>(
-            out_data, x->data<T>(), N, dim);
-      } else {
-        assert(false && "not support");
+    if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
+      if (dim == 128 && N % warps_per_block == 0) {
+        optimize = true;
+        // a warp for a batch, 4 elements for a thread; only supports the
+        // softmax dim size = 128 case currently
+        if (sizeof(T) == 2) {
+          VecSoftmaxForward<T, int2, 4, warps_per_block><<<
+              N / warps_per_block, warps_per_block * WARP_SIZE, 0,
+              ctx.cuda_device_context().stream()>>>(out_data, x->data<T>(), N,
+                                                    dim);
+        } else if (sizeof(T) == 4) {
+          VecSoftmaxForward<T, int4, 4, warps_per_block><<<
+              N / warps_per_block, warps_per_block * WARP_SIZE, 0,
+              ctx.cuda_device_context().stream()>>>(out_data, x->data<T>(), N,
+                                                    dim);
+        } else {
+          assert(false && "not support");
+        }
+      } else if (dim < max_dim) {
+        optimize = true;
+        int log2_elements = static_cast<int>(log2_ceil(dim));
+        const int next_power_of_two = 1 << log2_elements;
+
+        int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+
+        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+        // use 128 threads per block to maximize gpu utilization
+        constexpr int threads_per_block = 128;
+
+        int warps_per_block = (threads_per_block / warp_size);
+        int batches_per_block = warps_per_block * batches_per_warp;
+        int blocks = (N + batches_per_block - 1) / batches_per_block;
+        dim3 threads(warp_size, warps_per_block, 1);
+
+        switch (log2_elements) {
+          LAUNCH_SOFTMAX_WARP_FORWARD(0);  // 1
+          LAUNCH_SOFTMAX_WARP_FORWARD(1);  // 2
+          LAUNCH_SOFTMAX_WARP_FORWARD(2);  // 4
+          LAUNCH_SOFTMAX_WARP_FORWARD(3);  // 8
+          LAUNCH_SOFTMAX_WARP_FORWARD(4);  // 16
+          LAUNCH_SOFTMAX_WARP_FORWARD(5);  // 32
+          LAUNCH_SOFTMAX_WARP_FORWARD(6);  // 64
+          LAUNCH_SOFTMAX_WARP_FORWARD(7);  // 128
+          LAUNCH_SOFTMAX_WARP_FORWARD(8);  // 256
+          LAUNCH_SOFTMAX_WARP_FORWARD(9);  // 512
+          default:
+            break;
+        }
       }
-    } else {
+    }
+    if (!optimize) {
       ScopedTensorDescriptor desc;
       std::vector<int> tensor_dims = {N, dim, D, 1};
       DataLayout layout = DataLayout::kNCHW;
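
The launch configuration computed in the dim < max_dim branch can be reproduced on the host. The sketch below is illustrative only: it re-declares log2_ceil and picks an example N and dim that are not taken from the patch, but the derivation mirrors the code in the switch dispatch above.

// Standalone host-side check of the warp-softmax launch parameters.
// Illustrative only; N and dim are arbitrary examples, not values from the patch.
#include <cstdio>

static int log2_ceil(int value) {
  int log2_value = 0;
  while ((1 << log2_value) < value) ++log2_value;
  return log2_value;
}

int main() {
  const int N = 4096;   // number of softmax rows (batch size)
  const int dim = 100;  // softmax dimension; this path requires dim < 320

  const int log2_elements = log2_ceil(dim);                                 // 7
  const int next_power_of_two = 1 << log2_elements;                         // 128
  const int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;  // 32
  const int warp_iterations = next_power_of_two / warp_size;                // 4
  const int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;          // 2
  const int threads_per_block = 128;
  const int warps_per_block = threads_per_block / warp_size;                // 4
  const int batches_per_block = warps_per_block * batches_per_warp;         // 8
  const int blocks = (N + batches_per_block - 1) / batches_per_block;       // 512

  std::printf("dim=%d -> %d blocks of (%d, %d) threads, "
              "%d iterations per lane, %d rows per warp\n",
              dim, blocks, warp_size, warps_per_block, warp_iterations,
              batches_per_warp);
  return 0;
}

For dim = 100 this selects WarpSoftmaxForward<T, float, 7>: each 32x4 block covers 8 rows, each warp handles 2 rows, and each lane loads 4 elements of its row, with out-of-range slots padded with -infinity so they do not affect the max and sum reductions.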