From f8bab5b0671b1ffb86d25c06774bc2d121c2a098 Mon Sep 17 00:00:00 2001 From: AshburnLee <1578034415@qq.com> Date: Sat, 10 Apr 2021 13:01:24 +0800 Subject: [PATCH] Optimize the performance of the forward of log_softmax when axis is -1 and dim <= 1024 (#31630) --- paddle/fluid/operators/log_softmax_op.cu | 170 +++++++++++++++++++++++ 1 file changed, 170 insertions(+) diff --git a/paddle/fluid/operators/log_softmax_op.cu b/paddle/fluid/operators/log_softmax_op.cu index 02fca246d2..9136de38ca 100644 --- a/paddle/fluid/operators/log_softmax_op.cu +++ b/paddle/fluid/operators/log_softmax_op.cu @@ -12,7 +12,177 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include +#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/log_softmax_op.h" +#include "paddle/fluid/platform/cuda_device_function.h" + +namespace paddle { +namespace operators { + +#define LAUNCH_WARP_FORWAR_COMPUTE(near_greater_power_of_two) \ + case near_greater_power_of_two: \ + ComputeLogSoftmaxForwardInWarp< \ + T, AccT, near_greater_power_of_two><<>>( \ + dst, src, outer_size, dim_size); \ + break; + +template +__device__ __forceinline__ T WarpReduceSum(T value) { +#pragma unroll + for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) { + T sum_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset); + value = value + sum_val; + } + return value; +} + +template +__device__ __forceinline__ T WarpReduceMax(T value) { +#pragma unroll + for (int offset = KernelWarpSize / 2; offset > 0; offset /= 2) { + T max_val = platform::CudaShuffleXorSync(0xFFFFFFFF, value, offset); + value = max(value, max_val); + } + return value; +} + +int GetNearGreaterPowerOfTwo(int value) { + int log2_value = 0; + while ((1 << log2_value) < value) { + ++log2_value; + } + return 1 << log2_value; +} + +template +__global__ void ComputeLogSoftmaxForwardInWarp(T *dst, const T *src, + int batch_size, + int element_count) { + constexpr int near_greater_power_of_two = NearGreaterPowerOfTwo; + constexpr int kernel_warp_size = + (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32; + constexpr int warp_iter = near_greater_power_of_two / kernel_warp_size; + int batch_id = blockDim.y * blockIdx.x + threadIdx.y; + + // set effective_warp_id as 1 when warps do effective work, + // when warps do ineffective work, effective_warp_id remains unchanged. + int effective_warp_id = batch_size - batch_id; + if (effective_warp_id > 1) effective_warp_id = 1; + + int thread_in_warp_idx = threadIdx.x; + + // 1.read data from global memory to registers + AccT elements[warp_iter]; + // set effective_element_count as the num of elements when warps do effective + // work + // set effective_element_count as 0, when warps do ineffective work + int effective_element_count = (effective_warp_id <= 0) ? 0 : element_count; + for (int it = 0; it < warp_iter; ++it) { + int element_index = thread_in_warp_idx + it * kernel_warp_size; + if (element_index < effective_element_count) { + elements[it] = + static_cast(src[batch_id * element_count + element_index]); + } else { + elements[it] = -std::numeric_limits::infinity(); + } + } + + // 2.compute max_value. For each thread, loop all registers to find max + AccT max_value = elements[0]; +#pragma unroll + for (int it = 1; it < warp_iter; ++it) { + max_value = (max_value > elements[it]) ? max_value : elements[it]; + } + max_value = WarpReduceMax(max_value); + + // 3.For each warp, accumulate all thread registers + AccT sum = 0.0f; +#pragma unroll + for (int it = 0; it < warp_iter; ++it) { + sum += std::exp(elements[it] - max_value); + } + sum = WarpReduceSum(sum); + + // 4.store result. + sum = std::log(sum); +#pragma unroll + for (int it = 0; it < warp_iter; ++it) { + int element_index = thread_in_warp_idx + it * kernel_warp_size; + if (element_index < element_count) { + dst[batch_id * element_count + element_index] = + static_cast(elements[it] - max_value - sum); + } else { + break; + } + } +} + +template +void LaunchSoftmaxForwardForLastAxis(T *dst, const T *src, int dim_size, + int outer_size, gpuStream_t stream) { + int threads_per_block = 128; + int near_greater_power_of_two = GetNearGreaterPowerOfTwo(dim_size); + int kernel_warp_size = + (near_greater_power_of_two < 32) ? near_greater_power_of_two : 32; + int warps_per_block = (threads_per_block / kernel_warp_size); + int blocks = (outer_size + warps_per_block - 1) / warps_per_block; + dim3 threads(kernel_warp_size, warps_per_block, 1); + + switch (near_greater_power_of_two) { + LAUNCH_WARP_FORWAR_COMPUTE(1); + LAUNCH_WARP_FORWAR_COMPUTE(2); + LAUNCH_WARP_FORWAR_COMPUTE(4); // dim_size: 3~4 + LAUNCH_WARP_FORWAR_COMPUTE(8); // dim_size: 5~8 + LAUNCH_WARP_FORWAR_COMPUTE(16); // dim_size: 9~16 + LAUNCH_WARP_FORWAR_COMPUTE(32); // dim_size: 17~32 + LAUNCH_WARP_FORWAR_COMPUTE(64); // dim_size: 33~64 + LAUNCH_WARP_FORWAR_COMPUTE(128); // dim_size 65~128 + LAUNCH_WARP_FORWAR_COMPUTE(256); // dim_size 129~256 + LAUNCH_WARP_FORWAR_COMPUTE(512); // dim_size 257~512 + LAUNCH_WARP_FORWAR_COMPUTE(1024); // dim_size 513~1024 + + default: + break; + } +} + +template +class LogSoftmaxKernel + : public framework::OpKernel { + using MPDType = typename details::MPTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext &context) const override { + const auto *x = context.Input("X"); + auto *out = context.Output("Out"); + const auto *input_data = x->data(); + auto *output_data = out->mutable_data(context.GetPlace()); + + const int rank = x->dims().size(); + const int axis = CanonicalAxis(context.Attr("axis"), rank); + + int dim_size = x->dims()[axis]; + int inner_size = 1; + for (int i = axis + 1; i < x->dims().size(); ++i) { + inner_size *= x->dims()[i]; + } + int outer_size = SizeToAxis(axis, x->dims()); + gpuStream_t stream = context.cuda_device_context().stream(); + + if (inner_size == 1 && dim_size <= 1024 && dim_size * sizeof(T) <= 4096) { + LaunchSoftmaxForwardForLastAxis(output_data, input_data, + dim_size, outer_size, stream); + } else { + LogSoftmaxFunctor()( + context.template device_context(), x, + out, axis); + } + } +}; + +} // operators +} // paddle namespace ops = paddle::operators; namespace plat = paddle::platform; -- GitLab