diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu b/paddle/fluid/operators/softmax_cudnn_op.cu index 26d4f7a5e97fb2106dd9ae01d0343d763156e017..ac7963dd8ad437e5646ca56ee7466488c44a504f 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu +++ b/paddle/fluid/operators/softmax_cudnn_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/softmax_op.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/cudnn_helper.h" +#include "paddle/fluid/platform/gpu_launch_config.h" namespace paddle { namespace platform { @@ -39,6 +40,13 @@ using Tensor = framework::Tensor; out_data, x->data(), N, dim, dim); \ break; +#define LAUNCH_SOFTMAX_WARP_BACKWARD(Log2Elements) \ + case Log2Elements: \ + softmax_warp_backward<<< \ + blocks, threads, 0, ctx.cuda_device_context().stream()>>>( \ + dx_data, mul_grad.data(), out->data(), N, dim, dim); \ + break; + static inline int SizeOutAxis(const int axis, DDim dims) { int size = 1; for (int i = axis + 1; i < dims.size(); i++) { @@ -199,6 +207,83 @@ __global__ void WarpSoftmaxForward(T* dst, const T* src, const int batch_size, } } +template +__global__ void softmax_warp_backward(T* gradInput, const T* grad, + const T* output, int batch_size, + int stride, int element_count) { + constexpr int next_power_of_two = 1 << Log2Elements; + constexpr int warp_size_softmax = + (next_power_of_two < 32) ? next_power_of_two : 32; + constexpr int WARP_ITERATIONS = next_power_of_two / warp_size_softmax; + constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1; + + int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH; + + int local_batches = batch_size - first_batch; + if (local_batches > WARP_BATCH) { + local_batches = WARP_BATCH; + } + + int local_idx = threadIdx.x % warp_size_softmax; + + int thread_offset = first_batch * stride + local_idx; + grad += thread_offset; + output += thread_offset; + gradInput += thread_offset; + + // load data from global memory + AccT grad_reg[WARP_BATCH][WARP_ITERATIONS]; + AccT output_reg[WARP_BATCH][WARP_ITERATIONS]; + for (int i = 0; i < WARP_BATCH; ++i) { + int batch_element_count = (i >= local_batches) ? 0 : element_count; + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * warp_size_softmax; + if (element_index < batch_element_count) { + grad_reg[i][it] = + static_cast(grad[i * element_count + it * warp_size_softmax]); + output_reg[i][it] = static_cast( + output[i * element_count + it * warp_size_softmax]); + } else { + grad_reg[i][it] = AccT(0); + output_reg[i][it] = AccT(0); + } + } + } + + AccT sum[WARP_BATCH]; +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + sum[i] = grad_reg[i][0]; +#pragma unroll + for (int it = 1; it < WARP_ITERATIONS; ++it) { + sum[i] += grad_reg[i][it]; + } + } + warp_reduce_sum(sum); + +// store result +#pragma unroll + for (int i = 0; i < WARP_BATCH; ++i) { + if (i >= local_batches) break; +#pragma unroll + for (int it = 0; it < WARP_ITERATIONS; ++it) { + int element_index = local_idx + it * warp_size_softmax; + if (element_index < element_count) { + // compute gradients + gradInput[i * element_count + it * warp_size_softmax] = + (grad_reg[i][it] - output_reg[i][it] * sum[i]); + } + } + } +} + +template +__global__ void MultiplyCUDAKernel(T* C, const T* A, const T* B, int N) { + CUDA_KERNEL_LOOP(i, N) { + C[i] = static_cast(static_cast(A[i]) * static_cast(B[i])); + } +} + template __global__ void VecSoftmaxBackward(T* dst, const T* grad, const T* src, const int batch_size, @@ -340,28 +425,74 @@ class SoftmaxGradCUDNNKernel : public framework::OpKernel { constexpr bool warp_softmax_available = std::is_same::value || std::is_same::value; - if (D == 1 && dim == 128 && N % warps_per_block == 0 && - warp_softmax_available) { - if (std::is_same::value) { - VecSoftmaxBackward< - float, 4, - warps_per_block><<>>( - dx->data(), dout->data(), out->data(), N, dim); - } else if (std::is_same::value) { - VecSoftmaxBackward< - platform::float16, 4, - warps_per_block><<>>( - dx->data(), dout->data(), - out->data(), N, dim); - } else { - PADDLE_ENFORCE_EQ( - warp_softmax_available, true, - platform::errors::Unimplemented( - "Warp softmax backward is only available for fp32 and fp16")); + bool optimize = false; + if (D == 1 && warp_softmax_available) { + if (dim == 128 && N % warps_per_block == 0) { + optimize = true; + if (std::is_same::value) { + VecSoftmaxBackward<<< + N / warps_per_block, warps_per_block * WARP_SIZE, 0, + ctx.cuda_device_context().stream()>>>(dx->data(), + dout->data(), + out->data(), N, dim); + } else if (std::is_same::value) { + VecSoftmaxBackward<<< + N / warps_per_block, warps_per_block * WARP_SIZE, 0, + ctx.cuda_device_context().stream()>>>( + dx->data(), dout->data(), + out->data(), N, dim); + } else { + PADDLE_ENFORCE_EQ( + warp_softmax_available, true, + platform::errors::Unimplemented( + "Warp softmax backward is only available for fp32 and fp16")); + } + } else if (dim < 40 && dim % 32 != 0) { + optimize = true; + Tensor mul_grad; + int numel = N * dim; + mul_grad.mutable_data({numel}, ctx.GetPlace()); + + auto stream = ctx.cuda_device_context().stream(); + auto& dev_ctx = + ctx.template device_context(); + auto config = GetGpuLaunchConfig1D(dev_ctx, numel); + + MultiplyCUDAKernel<<>>( + mul_grad.data(), dout->data(), out->data(), numel); + + int log2_elements = log2_ceil(dim); + const int next_power_of_two = 1 << log2_elements; + + int warp_size = (next_power_of_two < 32) ? next_power_of_two : 32; + + int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1; + + constexpr int threads_per_block = 128; + + int warps_per_block = (threads_per_block / warp_size); + int batches_per_block = warps_per_block * batches_per_warp; + int blocks = (N + batches_per_block - 1) / batches_per_block; + dim3 threads(warp_size, warps_per_block, 1); + + switch (log2_elements) { + LAUNCH_SOFTMAX_WARP_BACKWARD(0); // 1 + LAUNCH_SOFTMAX_WARP_BACKWARD(1); // 2 + LAUNCH_SOFTMAX_WARP_BACKWARD(2); // 4 + LAUNCH_SOFTMAX_WARP_BACKWARD(3); // 8 + LAUNCH_SOFTMAX_WARP_BACKWARD(4); // 16 + LAUNCH_SOFTMAX_WARP_BACKWARD(5); // 32 + LAUNCH_SOFTMAX_WARP_BACKWARD(6); // 64 + LAUNCH_SOFTMAX_WARP_BACKWARD(7); // 128 + LAUNCH_SOFTMAX_WARP_BACKWARD(8); // 256 + LAUNCH_SOFTMAX_WARP_BACKWARD(9); // 512 + default: + break; + } } - } else { + } + if (!optimize) { ScopedTensorDescriptor desc; std::vector tensor_dims = {N, dim, D, 1}; DataLayout layout = DataLayout::kNCHW;