From 93e1bb9813ba8b0673f60837d75154e55f447418 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <72954905+Asthestarsfalll@users.noreply.github.com> Date: Mon, 5 Jun 2023 10:32:50 +0800 Subject: [PATCH] optimize logsumexp in small data scale (#52952) * optimize logsumexp in small data scale * fix * fix * add #pragma once * swith to use aligned_vector and support arbitrarily shape * fix store * fix store * refine for special cases * try * fix * update * fix * fix all_reduce * try * fix rocm bug * fix rocm bug * fix rocm bug * fix rocm bug * fix rocm bug * fix rocm bug * fix rocm bug * fix rocm bug --- .../phi/kernels/gpu/logsumexp_function.cu.h | 487 ++++++++++++++++++ paddle/phi/kernels/gpu/logsumexp_kernel.cu | 121 ++++- 2 files changed, 580 insertions(+), 28 deletions(-) create mode 100644 paddle/phi/kernels/gpu/logsumexp_function.cu.h diff --git a/paddle/phi/kernels/gpu/logsumexp_function.cu.h b/paddle/phi/kernels/gpu/logsumexp_function.cu.h new file mode 100644 index 00000000000..53b6fb6d2b2 --- /dev/null +++ b/paddle/phi/kernels/gpu/logsumexp_function.cu.h @@ -0,0 +1,487 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// Licensed under the Apache License, Version 2.0 (the "License"); +// +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/primitive/functor_primitives.h" + +#define CUDART_INF __longlong_as_double(0x7ff0000000000000ULL) +#define CUDART_INF_F __int_as_float(0x7f800000) + +namespace phi { +namespace funcs { + +constexpr int kWarpSize = 32; + +template +__inline__ __device__ T Inf(); + +template <> +__inline__ __device__ float Inf() { + return CUDART_INF_F; +} + +template <> +__inline__ __device__ double Inf() { + return CUDART_INF; +} + +template + class Functor, + int ThreadGroupWidth = kWarpSize> +__inline__ __device__ T WarpAllReduce(T val) { + for (int mask = ThreadGroupWidth / 2; mask > 0; mask /= 2) { +#if PADDLE_WITH_HIP + val = Functor()(val, __shfl_xor(0xffffffff, val, mask)); +#else + val = Functor()(val, __shfl_xor_sync(0xffffffff, val, mask)); +#endif + } + return val; +} + +#if PADDLE_WITH_HIP +inline void GetNumBlocks(int64_t block_size, + int64_t max_blocks, + int64_t waves, + int* num_blocks) { + int dev; + PADDLE_ENFORCE_GPU_SUCCESS(hipGetDevice(&dev)); + int sm_count; + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute( + &sm_count, hipDeviceAttributeMultiprocessorCount, dev)); + int tpm; + PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceGetAttribute( + &tpm, hipDeviceAttributeMaxThreadsPerMultiProcessor, dev)); + *num_blocks = std::max( + 1, std::min(max_blocks, sm_count * tpm / block_size * waves)); +} +#else +inline void GetNumBlocks(int64_t block_size, + int64_t max_blocks, + int64_t waves, + int* num_blocks) { + int dev; + PADDLE_ENFORCE_GPU_SUCCESS(cudaGetDevice(&dev)); + int sm_count; + PADDLE_ENFORCE_GPU_SUCCESS( + cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev)); + int tpm; + PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceGetAttribute( + &tpm, cudaDevAttrMaxThreadsPerMultiProcessor, 
dev)); + *num_blocks = std::max( + 1, std::min(max_blocks, sm_count * tpm / block_size * waves)); +} +#endif + +template +__global__ void LogsumexpWarpImpl(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { + static_assert(ColsPerThread % VecSize == 0, ""); + static_assert(ThreadGroupWidth <= kWarpSize, ""); + static_assert(kWarpSize % ThreadGroupWidth == 0, ""); + constexpr int num_read = ColsPerThread / VecSize; + assert(num_col <= ColsPerThread * ThreadGroupWidth); + const int group_id = blockIdx.x * blockDim.y + threadIdx.y; + const int num_thread_group = gridDim.x * blockDim.y; + const int thread_id = threadIdx.x; + const int step = num_thread_group * RowsPerThread; + + using LoadType = phi::AlignedVector; + using StoreType = phi::AlignedVector; + + LoadType load_vec; + StoreType store_vec; + + T buffer[RowsPerThread][ColsPerThread]; + + for (int64_t cur_row = group_id * RowsPerThread; cur_row < num_row; + cur_row += step) { + T thread_max[RowsPerThread]; +// Read data +#pragma unroll + for (int row_id = 0; row_id < RowsPerThread; row_id++) { + thread_max[row_id] = -Inf(); + T* row_buffer = buffer[row_id]; +#pragma unroll + for (int read_id = 0; read_id < num_read; read_id++) { + const int offset = read_id * VecSize; + const int cur_col = (read_id * ThreadGroupWidth + thread_id) * VecSize; + if (!NeedPadding || cur_col < num_col) { + int64_t load_offset = ((cur_row + row_id) * num_col + cur_col); + phi::Load(in + load_offset, &load_vec); +#pragma unroll + for (int i = 0; i < VecSize; i++) { + row_buffer[offset + i] = static_cast(load_vec[i]); + thread_max[row_id] = + max(thread_max[row_id], row_buffer[offset + i]); + } + } else { +#pragma unroll + for (int i = 0; i < VecSize; i++) { + row_buffer[offset + i] = -Inf(); + } + } + } + } + T warp_max[RowsPerThread]; +// Get warp max +#pragma unroll + for (int row_id = 0; row_id < RowsPerThread; row_id++) { + warp_max[row_id] = WarpAllReduce( + thread_max[row_id]); + } + T thread_sum[RowsPerThread]; +// Calculate +#pragma unroll + for (int row_id = 0; row_id < RowsPerThread; row_id++) { + thread_sum[row_id] = 0; + T* row_buffer = buffer[row_id]; +#pragma unroll + for (int i = 0; i < ColsPerThread; i++) { + thread_sum[row_id] += exp(row_buffer[i] - warp_max[row_id]); + } + } +// Get warp sum and write +#pragma unroll + for (int row_id = 0; row_id < RowsPerThread; row_id++) { + T res = log(WarpAllReduce( + thread_sum[row_id])); + store_vec[row_id] = static_cast(res + warp_max[row_id]); + } + if (thread_id == 0 && cur_row < num_row) { + phi::Store(store_vec, + out + group_id * RowsPerThread); + } + } +} + +template +#if PADDLE_WITH_HIP +inline hipError_t LaunchLogsumexpWarp(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#else +inline cudaError_t LaunchLogsumexpWarp(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#endif + constexpr int block_size = 128; + constexpr int waves = 32; + static_assert(block_size % ThreadGroupWidth == 0, ""); + constexpr int thread_groups_per_block = block_size / ThreadGroupWidth; + dim3 block_dim(ThreadGroupWidth, thread_groups_per_block); + const int64_t num_blocks = + (num_row / RowsPerThread + thread_groups_per_block - 1) / + thread_groups_per_block; + int grid_dim_x; + { GetNumBlocks(block_size, num_blocks, waves, &grid_dim_x); } + LogsumexpWarpImpl + <<>>( + dev_ctx, num_row, num_col, in, out); +#if 
PADDLE_WITH_HIP + return hipPeekAtLastError(); +#else + return cudaPeekAtLastError(); +#endif +} + +template +#if PADDLE_WITH_HIP +inline hipError_t DispatchLogsumexpWarpWithPadding(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#else +inline cudaError_t DispatchLogsumexpWarpWithPadding(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#endif + if (num_col == ColsPerThread * ThreadGroupWidth) { + return LaunchLogsumexpWarp(dev_ctx, num_row, num_col, in, out); + } else { + return LaunchLogsumexpWarp(dev_ctx, num_row, num_col, in, out); + } +} + +template +#if PADDLE_WITH_HIP +typename std::enable_if::type +DispatchLogsumexpWarpCols(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#else +typename std::enable_if::type +DispatchLogsumexpWarpCols(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#endif + if (num_col <= 0) { +#if PADDLE_WITH_HIP + return hipErrorInvalidValue; +#else + return cudaErrorInvalidValue; +#endif + } +#define HANDLE_THREAD_GROUP(thread_group_width) \ + if (num_col <= (thread_group_width)*VecSize) { \ + if (num_row % 2 == 0) { \ + return DispatchLogsumexpWarpWithPadding( \ + dev_ctx, num_row, num_col, in, out); \ + } else { \ + return DispatchLogsumexpWarpWithPadding( \ + dev_ctx, num_row, num_col, in, out); \ + } \ + } + HANDLE_THREAD_GROUP(1) + HANDLE_THREAD_GROUP(2) + HANDLE_THREAD_GROUP(4) + HANDLE_THREAD_GROUP(8) + HANDLE_THREAD_GROUP(16) + HANDLE_THREAD_GROUP(32) +#undef HANDLE_ROWS +// if num_col > 32 +#define HANDLE_COL(col) \ + if (num_col <= (col)*kWarpSize) { \ + return DispatchLogsumexpWarpWithPadding( \ + dev_ctx, num_row, num_col, in, out); \ + } + + HANDLE_COL(2) + HANDLE_COL(3) + HANDLE_COL(4) + HANDLE_COL(5) + HANDLE_COL(6) + HANDLE_COL(7) + HANDLE_COL(8) + HANDLE_COL(9) + HANDLE_COL(10) + HANDLE_COL(11) + HANDLE_COL(12) + HANDLE_COL(13) + HANDLE_COL(14) + HANDLE_COL(15) + HANDLE_COL(16) + HANDLE_COL(17) + HANDLE_COL(18) + HANDLE_COL(19) + HANDLE_COL(20) + HANDLE_COL(21) + HANDLE_COL(22) + HANDLE_COL(23) + HANDLE_COL(24) + HANDLE_COL(25) + HANDLE_COL(26) + HANDLE_COL(27) + HANDLE_COL(28) + HANDLE_COL(29) + HANDLE_COL(30) + HANDLE_COL(31) + HANDLE_COL(32) +#undef HANDLE_COL +#if PADDLE_WITH_HIP + return hipErrorInvalidValue; +#else + return cudaErrorInvalidValue; +#endif +} + +template +#if PADDLE_WITH_HIP +typename std::enable_if::type +DispatchLogsumexpWarpCols(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#else +typename std::enable_if::type +DispatchLogsumexpWarpCols(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#endif + if (num_col <= 0) { +#if PADDLE_WITH_HIP + return hipErrorInvalidValue; +#else + return cudaErrorInvalidValue; +#endif + } +#define HANDLE_THREAD_GROUP(thread_group_width) \ + if (num_col <= (thread_group_width)*VecSize) { \ + if (num_row % 2 == 0) { \ + return DispatchLogsumexpWarpWithPadding( \ + dev_ctx, num_row, num_col, in, out); \ + } else { \ + return DispatchLogsumexpWarpWithPadding( \ + dev_ctx, num_row, num_col, in, out); \ + } \ + } + HANDLE_THREAD_GROUP(1) + HANDLE_THREAD_GROUP(2) + HANDLE_THREAD_GROUP(4) + HANDLE_THREAD_GROUP(8) + HANDLE_THREAD_GROUP(16) + HANDLE_THREAD_GROUP(32) +#undef 
HANDLE_THREAD_GROUP +// if num_col > 32 +#define HANDLE_COL(col) \ + if (num_col <= (col)*kWarpSize) { \ + return DispatchLogsumexpWarpWithPadding( \ + dev_ctx, num_row, num_col, in, out); \ + } + + HANDLE_COL(4) + HANDLE_COL(6) + HANDLE_COL(8) + HANDLE_COL(10) + HANDLE_COL(12) + HANDLE_COL(14) + HANDLE_COL(16) + HANDLE_COL(18) + HANDLE_COL(20) + HANDLE_COL(22) + HANDLE_COL(24) + HANDLE_COL(26) + HANDLE_COL(28) + HANDLE_COL(30) + HANDLE_COL(32) +#undef HANDLE_COL +#if PADDLE_WITH_HIP + return hipErrorInvalidValue; +#else + return cudaErrorInvalidValue; +#endif +} + +template +#if PADDLE_WITH_HIP +inline hipError_t DispatchLogsumexpWarp(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#else +inline cudaError_t DispatchLogsumexpWarp(const Context& dev_ctx, + const int64_t num_row, + const int64_t num_col, + const SourceType* in, + SourceType* out) { +#endif + // dispatch logsumexp warp with vecsize + if (num_col % 2 == 0) { + return DispatchLogsumexpWarpCols( + dev_ctx, num_row, num_col, in, out); + } else { + return DispatchLogsumexpWarpCols( + dev_ctx, num_row, num_col, in, out); + } +} +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/logsumexp_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_kernel.cu index 7d7dd7ba175..72f878c38dd 100644 --- a/paddle/phi/kernels/gpu/logsumexp_kernel.cu +++ b/paddle/phi/kernels/gpu/logsumexp_kernel.cu @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/kernels/logsumexp_kernel.h" +#include "paddle/phi/kernels/gpu/logsumexp_function.cu.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" @@ -21,10 +22,26 @@ #include "paddle/phi/kernels/elementwise_subtract_kernel.h" #include "paddle/phi/kernels/funcs/activation_functor.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" +#include "paddle/phi/kernels/funcs/transpose_function.cu.h" #include "paddle/phi/kernels/gpu/reduce.h" namespace phi { +template +struct ComputeType { + using type = T; +}; + +template <> +struct ComputeType { + using type = float; +}; + +template <> +struct ComputeType { + using type = float; +}; + template struct LogCUDAFunctor { HOSTDEVICE inline T operator()(const T x) const { return std::log(x); } @@ -46,6 +63,44 @@ struct LogCUDAFunctor { } }; +template +void LogsumexpFallbackKernel(const Context& dev_ctx, + const DenseTensor& x, + const std::vector& axis_vec, + const std::vector& outdim_vec, + const std::vector& keeped_outdim_vec, + bool keepdim, + bool reduce_all, + DenseTensor* out) { + auto* in_x = &x; + auto* out_y = out; + + auto outdim = phi::make_ddim(outdim_vec); + auto keeped_outdim = phi::make_ddim(keeped_outdim_vec); + out->Resize(outdim); + dev_ctx.template Alloc(out_y); + + DenseTensor max_x; + max_x.Resize(outdim); + dev_ctx.template Alloc(&max_x); + + phi::funcs::ReduceKernel>( + dev_ctx, *in_x, &max_x, kps::IdentityFunctor(), axis_vec); + + max_x.Resize(keeped_outdim); + DenseTensor temp_x = Subtract(dev_ctx, *in_x, max_x); + phi::funcs::ReduceKernel>( + dev_ctx, temp_x, out_y, kps::ExpFunctor(), axis_vec); + + const std::vector inputs = {out_y}; + std::vector outputs = {&temp_x}; + phi::funcs::ElementwiseKernel( + dev_ctx, inputs, &outputs, LogCUDAFunctor()); + temp_x.Resize(outdim); + out->Resize(outdim); + phi::AddKernel(dev_ctx, temp_x, max_x, out); +} + template void LogsumexpKernel(const Context& dev_ctx, const DenseTensor& x, @@ -53,9 +108,7 @@ void LogsumexpKernel(const Context& dev_ctx, bool keepdim, bool 
reduce_all, DenseTensor* out) { - auto* in_x = &x; - auto* out_y = out; - auto xdim = in_x->dims(); + auto xdim = x.dims(); for (size_t i = 0; i < xdim.size(); i++) PADDLE_ENFORCE_LT(0, xdim[i], @@ -63,13 +116,15 @@ void LogsumexpKernel(const Context& dev_ctx, "The dims of Input(X) should be greater than 0.")); reduce_all = recompute_reduce_all(x, axis, reduce_all); - std::vector outdim_vec, keeped_outdim_vec; - std::vector axis_vec; + std::vector outdim_vec, keeped_outdim_vec, transpose_shape; + std::vector axis_vec, perm; + int64_t compute_size = 1, other_size = 1; for (auto i : axis) { auto v = i >= 0 ? i : i + xdim.size(); axis_vec.push_back(v); } if (axis.size() == 0 || reduce_all) { + axis_vec.clear(); for (size_t i = 0; i < xdim.size(); i++) { axis_vec.push_back(i); } @@ -83,38 +138,48 @@ void LogsumexpKernel(const Context& dev_ctx, } } if (flag) { + compute_size *= xdim[i]; keeped_outdim_vec.push_back(1); if (keepdim) outdim_vec.push_back(1); } else { + other_size *= xdim[i]; + transpose_shape.push_back(xdim[i]); + perm.push_back(i); outdim_vec.push_back(xdim[i]); keeped_outdim_vec.push_back(xdim[i]); } } auto outdim = phi::make_ddim(outdim_vec); - auto keeped_outdim = phi::make_ddim(keeped_outdim_vec); - out->Resize(outdim); - dev_ctx.template Alloc(out_y); - - DenseTensor max_x; - max_x.Resize(outdim); - dev_ctx.template Alloc(&max_x); - - phi::funcs::ReduceKernel>( - dev_ctx, *in_x, &max_x, kps::IdentityFunctor(), axis_vec); - - max_x.Resize(keeped_outdim); - DenseTensor temp_x = Subtract(dev_ctx, *in_x, max_x); - phi::funcs::ReduceKernel>( - dev_ctx, temp_x, out_y, kps::ExpFunctor(), axis_vec); - - const std::vector inputs = {out_y}; - std::vector outputs = {&temp_x}; - phi::funcs::ElementwiseKernel( - dev_ctx, inputs, &outputs, LogCUDAFunctor()); - temp_x.Resize(outdim); - out->Resize(outdim); - phi::AddKernel(dev_ctx, temp_x, max_x, out); + if (compute_size <= 1024) { + if (perm.size() != xdim.size()) + perm.insert(perm.end(), axis_vec.begin(), axis_vec.end()); + for (auto i : axis_vec) transpose_shape.push_back(xdim[i]); + DenseTensor transpose_x; + if (xdim.size() == 0 || + (axis_vec.size() == 1 && axis_vec[0] == xdim.size())) { + transpose_x = x; + } else { + transpose_x.Resize(make_ddim(transpose_shape)); + dev_ctx.template Alloc(&transpose_x); + phi::funcs::TransposeGPUKernelDriver(dev_ctx, x, perm, &transpose_x); + } + dev_ctx.template Alloc(out); + using compute_type = typename ComputeType::type; + const int64_t num_col = compute_size, num_row = other_size; + funcs::DispatchLogsumexpWarp( + dev_ctx, num_row, num_col, transpose_x.data(), out->data()); + out->Resize(outdim); + } else { + LogsumexpFallbackKernel(dev_ctx, + x, + axis_vec, + outdim_vec, + keeped_outdim_vec, + keepdim, + reduce_all, + out); + } } } // namespace phi -- GitLab
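For reference, a minimal standalone sketch of the warp-shuffle logsumexp pattern the new LogsumexpWarpImpl path is built on. It is not taken from this patch: it assumes float input, at most 32 columns, and one warp per row, and the kernel name LogsumexpRowsKernel plus the fixed 32x4 launch shape are illustrative choices, not Paddle APIs.

// Standalone sketch: one warp reduces one row of a [num_row, num_col]
// float matrix (num_col <= 32) with the numerically stable sequence
// warp-max -> exp(x - max) -> warp-sum -> log(sum) + max, using the same
// __shfl_xor_sync butterfly reduction the patch's WarpAllReduce relies on.
// Error checks on the CUDA runtime calls are omitted for brevity.
#include <cmath>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

constexpr int kWarpSize = 32;

__inline__ __device__ float WarpAllReduceMax(float val) {
  for (int mask = kWarpSize / 2; mask > 0; mask /= 2)
    val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, mask));
  return val;
}

__inline__ __device__ float WarpAllReduceSum(float val) {
  for (int mask = kWarpSize / 2; mask > 0; mask /= 2)
    val += __shfl_xor_sync(0xffffffffu, val, mask);
  return val;
}

// One warp per row; lanes beyond num_col contribute -inf to the max and
// 0 to the sum, so padding does not change the result.
__global__ void LogsumexpRowsKernel(const float* in, float* out,
                                    int64_t num_row, int64_t num_col) {
  const int lane = threadIdx.x;  // 0..31, one lane per column slot
  const int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= num_row) return;  // whole warp exits together, no divergence

  const float v = lane < num_col ? in[row * num_col + lane] : -INFINITY;
  const float row_max = WarpAllReduceMax(v);
  const float partial = lane < num_col ? expf(v - row_max) : 0.f;
  const float row_sum = WarpAllReduceSum(partial);
  if (lane == 0) out[row] = logf(row_sum) + row_max;
}

int main() {
  const int64_t num_row = 4, num_col = 7;
  std::vector<float> h_in(num_row * num_col);
  for (size_t i = 0; i < h_in.size(); ++i) h_in[i] = 0.1f * i;

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, h_in.size() * sizeof(float));
  cudaMalloc(&d_out, num_row * sizeof(float));
  cudaMemcpy(d_in, h_in.data(), h_in.size() * sizeof(float),
             cudaMemcpyHostToDevice);

  dim3 block(kWarpSize, 4);  // 4 warps per block, i.e. 4 rows per block
  dim3 grid((num_row + block.y - 1) / block.y);
  LogsumexpRowsKernel<<<grid, block>>>(d_in, d_out, num_row, num_col);

  std::vector<float> h_out(num_row);
  cudaMemcpy(h_out.data(), d_out, num_row * sizeof(float),
             cudaMemcpyDeviceToHost);
  for (int64_t r = 0; r < num_row; ++r)
    printf("row %lld: logsumexp = %f\n", (long long)r, h_out[r]);

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

Subtracting the warp-wide max before exponentiating is what keeps expf in range; the patch follows the same max -> exp-sum -> log + max sequence, but generalizes it with phi::AlignedVector loads of several columns per lane, configurable thread-group widths and rows-per-thread, half-precision inputs promoted to float through the ComputeType specializations, and a transpose plus fallback ReduceKernel path when the reduced extent exceeds the 1024-element threshold used in LogsumexpKernel.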