From 5dc0a6eb959694097c160cfebd50635cbbf20148 Mon Sep 17 00:00:00 2001
From: AshburnLee <1578034415@qq.com>
Date: Wed, 14 Apr 2021 22:06:37 +0800
Subject: [PATCH] Optimize of backward of log_softmax when axis is -1 and
 dim_size <= 1024 (#32180)

---
 paddle/fluid/operators/log_softmax_op.cu | 132 +++++++++++++++++++++--
 1 file changed, 126 insertions(+), 6 deletions(-)

diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu
index 9136de38caf..e4fe92c6256 100644
--- a/paddle/fluid/operators/log_softmax_op.cu
+++ b/paddle/fluid/operators/log_softmax_op.cu
@@ -65,11 +65,6 @@ __global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src,
   constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
   int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
 
-  // set effective_warp_id as 1 when warps do effective work,
-  // when warps do ineffective work, effective_warp_id remains unchanged.
-  int effective_warp_id = batch_size - batch_id;
-  if (effective_warp_id > 1) effective_warp_id = 1;
-
   int thread_in_warp_idx = threadIdx.x;
 
   // 1.read data from global memory to registers
@@ -77,7 +72,7 @@
   // set effective_element_count as the num of elements when warps do effective
   // work
   // set effective_element_count as 0, when warps do ineffective work
-  int effective_element_count = (effective_warp_id <= 0) ? 0 : element_count;
+  int effective_element_count = (batch_id < batch_size) ? element_count : 0;
   for (int it = 0; it < warp_iter; ++it) {
     int element_index = thread_in_warp_idx + it * kernel_warp_size;
     if (element_index < effective_element_count) {
@@ -181,6 +176,131 @@ class LogSoftmaxKernel
   }
 };
 
+// Backward below
+#define LAUNCH_WARP_BACKWARD_COMPUTE(near_greater_power_of_two)              \
+  case near_greater_power_of_two:                                            \
+    ComputeLogSoftmaxBackwardInWarp<                                          \
+        T, AccT, near_greater_power_of_two><<<blocks, threads, 0, stream>>>( \
+        output, grad_output, grad_input, outer_size, dim_size);              \
+    break;
+
+template <typename T, typename AccT, int NearGreaterPowerOfTwo>
+__global__ void ComputeLogSoftmaxBackwardInWarp(const T *output,
+                                                const T *grad_output,
+                                                T *grad_input, int batch_size,
+                                                int element_count) {
+  constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo;
+  constexpr int kernel_warp_size =
+      (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
+  constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size;
+  int batch_id = blockDim.y * blockIdx.x + threadIdx.y;
+
+  int thread_in_warp_idx = threadIdx.x % kernel_warp_size;
+
+  // 1.read data from global memory to registers
+  AccT output_register[warp_iter];
+  AccT grad_output_register[warp_iter];
+  int effective_element_count = (batch_id < batch_size) ? element_count : 0;
+  for (int iter = 0; iter < warp_iter; ++iter) {
+    int element_index = thread_in_warp_idx + iter * kernel_warp_size;
+    if (element_index < effective_element_count) {
+      output_register[iter] =
+          static_cast<AccT>(output[batch_id * element_count + element_index]);
+      grad_output_register[iter] = static_cast<AccT>(
+          grad_output[batch_id * element_count + element_index]);
+    } else {
+      output_register[iter] = AccT(0);
+      grad_output_register[iter] = AccT(0);
+    }
+  }
+
+  // 2. For each warp, accumulate all thread registers
+  AccT sum = grad_output_register[0];
+#pragma unroll
+  for (int iter = 1; iter < warp_iter; ++iter) {
+    sum += grad_output_register[iter];
+  }
+  sum = WarpReduceSum<AccT, kernel_warp_size>(sum);
+
+  // 3. write result in grad_input
+#pragma unroll
+  for (int iter = 0; iter < warp_iter; ++iter) {
+    int element_index = thread_in_warp_idx + iter * kernel_warp_size;
+    if (element_index < element_count) {
+      grad_input[batch_id * element_count + element_index] = static_cast<T>(
+          (grad_output_register[iter] - std::exp(output_register[iter]) * sum));
+    }
+  }
+}
+
+template <typename T, typename AccT>
+void LaunchSoftmaxBackwardForLastAxis(T *grad_input, const T *grad_output,
+                                      const T *output, int dim_size,
+                                      int outer_size, gpuStream_t stream) {
+  int threads_per_block = 128;
+  int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size);
+  int kernel_warp_size =
+      (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32;
+  int warps_per_block = (threads_per_block / kernel_warp_size);
+  int blocks = (outer_size + warps_per_block - 1) / warps_per_block;
+  dim3 threads(kernel_warp_size, warps_per_block, 1);
+
+  switch (near_greater_power_of_two) {
+    LAUNCH_WARP_BACKWARD_COMPUTE(1);     // dim_size: 1
+    LAUNCH_WARP_BACKWARD_COMPUTE(2);     // dim_size: 2
+    LAUNCH_WARP_BACKWARD_COMPUTE(4);     // dim_size: 3~4
+    LAUNCH_WARP_BACKWARD_COMPUTE(8);     // dim_size: 5~8
+    LAUNCH_WARP_BACKWARD_COMPUTE(16);    // dim_size: 9~16
+    LAUNCH_WARP_BACKWARD_COMPUTE(32);    // dim_size: 17~32
+    LAUNCH_WARP_BACKWARD_COMPUTE(64);    // dim_size: 33~64
+    LAUNCH_WARP_BACKWARD_COMPUTE(128);   // dim_size: 65~128
+    LAUNCH_WARP_BACKWARD_COMPUTE(256);   // dim_size: 129~256
+    LAUNCH_WARP_BACKWARD_COMPUTE(512);   // dim_size: 257~512
+    LAUNCH_WARP_BACKWARD_COMPUTE(1024);  // dim_size: 513~1024
+
+    default:
+      break;
+  }
+}
+
+template <typename T>
+class LogSoftmaxGradKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
+
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    const auto *out = context.Input<framework::Tensor>("Out");
+    const auto *g_out =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto *g_x = context.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    const auto *out_data = out->data<T>();
+    const auto *g_out_data = g_out->data<T>();
+    auto *g_x_data = g_x->mutable_data<T>(context.GetPlace());
+
+    const int rank = out->dims().size();
+    const int axis = CanonicalAxis(context.Attr<int>("axis"), rank);
+
+    int dim_size = out->dims()[axis];
+    int inner_size = 1;
+    for (int i = axis + 1; i < out->dims().size(); ++i) {
+      inner_size *= out->dims()[i];
+    }
+    int outer_size = SizeToAxis(axis, out->dims());
+    gpuStream_t stream = context.cuda_device_context().stream();
+
+    if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) {
+      LaunchSoftmaxBackwardForLastAxis<T, MPDType>(
+          g_x_data, g_out_data, out_data, dim_size, outer_size, stream);
+    } else {
+      LogSoftmaxGradFunctor<platform::CUDADeviceContext, T>()(
+          context.template device_context<platform::CUDADeviceContext>(), out,
+          g_out, g_x, axis);
+    }
+  }
+};
+
 }  // operators
 }  // paddle
-- 
GitLab
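
Reviewer note: for every row along the softmax axis, the warp-level kernel added above computes grad_input_i = grad_output_i - exp(output_i) * sum_j grad_output_j, where output is the saved log-softmax result. The standalone sketch below is not part of the patch; it is a minimal one-thread-per-row reference under assumed names (LogSoftmaxBackwardNaive and its parameters are illustrative) that can be used to cross-check the optimized path on small shapes.

// Illustrative reference only -- not part of the PaddlePaddle change above.
// Each thread handles one full row: dx_i = dy_i - exp(y_i) * sum_j(dy_j),
// where y = log_softmax(x) and dy is the upstream gradient.
__global__ void LogSoftmaxBackwardNaive(const float *output,
                                        const float *grad_output,
                                        float *grad_input, int outer_size,
                                        int dim_size) {
  int row = blockIdx.x * blockDim.x + threadIdx.x;
  if (row >= outer_size) return;

  const float *y = output + row * dim_size;
  const float *dy = grad_output + row * dim_size;
  float *dx = grad_input + row * dim_size;

  // sum of the upstream gradient over the softmax axis
  float sum = 0.f;
  for (int i = 0; i < dim_size; ++i) sum += dy[i];

  // write the per-element gradient
  for (int i = 0; i < dim_size; ++i) dx[i] = dy[i] - expf(y[i]) * sum;
}

Launched as LogSoftmaxBackwardNaive<<<(outer_size + 127) / 128, 128>>>(out, g_out, g_x, outer_size, dim_size) on the same inputs, its float32 results should agree with the warp-level kernel up to floating-point reduction order.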