From 45078d9fdba53561a38027820d72625904615051 Mon Sep 17 00:00:00 2001
From: Zhang Zheng <32410583+ZzSean@users.noreply.github.com>
Date: Wed, 30 Mar 2022 11:23:11 +0800
Subject: [PATCH] Optimize the perf of top_k when k is too large (#40941)

* Optimize the perf of top_k when k is too large

* fix rocm compile

* fix

* only compile in cuda

* fix log info
---
 paddle/fluid/operators/top_k_function_cuda.h | 423 +++++++++++++++++++
 paddle/phi/kernels/gpu/top_k_kernel.cu       |  69 ++-
 2 files changed, 489 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/operators/top_k_function_cuda.h b/paddle/fluid/operators/top_k_function_cuda.h
index 80c9935057c..a3da3d572e5 100644
--- a/paddle/fluid/operators/top_k_function_cuda.h
+++ b/paddle/fluid/operators/top_k_function_cuda.h
@@ -23,8 +23,10 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 #endif
 #include "paddle/fluid/operators/eigen/eigen_function.h"
+#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/fluid/platform/float16.h"
 
 #ifdef __HIPCC__
@@ -358,6 +360,427 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
   }
 }
+/*---------------------------Radix TopK Begin------------------*/
+#if defined(PADDLE_WITH_CUDA)
+constexpr int RADIX_BITS = 2;  // digits are base-(2 ^ RADIX_BITS)
+constexpr int RADIX_SIZE = 4;  // 2 ^ RADIX_BITS
+constexpr int RADIX_MASK = (RADIX_SIZE - 1);
+
+/*---------------------------Helper Structs------------------*/
+template <typename T>
+struct Bitfield {};
+
+template <>
+struct Bitfield<unsigned int> {
+  static __device__ __forceinline__ unsigned int GetBitfield(unsigned int val,
+                                                             int pos, int len) {
+    unsigned int ret;
+    asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+
+  static __device__ __forceinline__ unsigned int SetBitfield(
+      unsigned int val, unsigned int to_insert, int pos, int len) {
+    unsigned int ret;
+    asm("bfi.b32 %0, %1, %2, %3, %4;"
+        : "=r"(ret)
+        : "r"(to_insert), "r"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+};
+
+template <>
+struct Bitfield<uint64_t> {
+  static __device__ __forceinline__ uint64_t GetBitfield(uint64_t val, int pos,
+                                                         int len) {
+    uint64_t ret;
+    asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+
+  static __device__ __forceinline__ uint64_t SetBitfield(uint64_t val,
+                                                         uint64_t to_insert,
+                                                         int pos, int len) {
+    uint64_t ret;
+    asm("bfi.b64 %0, %1, %2, %3, %4;"
+        : "=l"(ret)
+        : "l"(to_insert), "l"(val), "r"(pos), "r"(len));
+    return ret;
+  }
+};
+
+template <typename T>
+struct RadixTypeConfig {};
+
+template <>
+struct RadixTypeConfig<float> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType Convert(float v) {
+    RadixType x = __float_as_int(v);
+    RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
+
+    return (v == v) ? (x ^ mask) : 0xffffffff;
+  }
+
+  static inline __device__ float Deconvert(RadixType v) {
+    RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
+
+    return __int_as_float(v ^ mask);
+  }
+};
+
+template <>
+struct RadixTypeConfig<double> {
+  typedef uint64_t RadixType;
+
+  static inline __device__ RadixType Convert(double v) {
+    RadixType x = __double_as_longlong(v);
+    RadixType mask = -((x >> 63)) | 0x8000000000000000;
+    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
+  }
+
+  static inline __device__ double Deconvert(RadixType v) {
+    RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
+    return __longlong_as_double(v ^ mask);
+  }
+};
+
+template <>
+struct RadixTypeConfig<int32_t> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType Convert(int32_t v) {
+    static_assert(sizeof(int) == 4, "");
+    return 2147483648u + v;
+  }
+
+  static inline __device__ int32_t Deconvert(RadixType v) {
+    return v - 2147483648u;
+  }
+};
+
+template <>
+struct RadixTypeConfig<int64_t> {
+  typedef uint64_t RadixType;
+
+  static inline __device__ RadixType Convert(int64_t v) {
+    static_assert(sizeof(int64_t) == 8, "");
+    return 9223372036854775808ull + v;
+  }
+
+  static inline __device__ int64_t Deconvert(RadixType v) {
+    return v - 9223372036854775808ull;
+  }
+};
+
+template <>
+struct RadixTypeConfig<platform::float16> {
+  typedef uint32_t RadixType;
+
+  static inline __device__ RadixType Convert(platform::float16 v) {
+    half v_h = v.to_half();
+    RadixType x = __half_as_ushort(v_h);
+    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
+    return (v_h == v_h) ? (x ^ mask) : 0xffff;
+  }
+
+  static inline __device__ platform::float16 Deconvert(RadixType v) {
+    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
+    return static_cast<platform::float16>(__ushort_as_half(v ^ mask));
+  }
+};
+
+/*---------------------------Helper Functions------------------*/
+__device__ __forceinline__ int GetLaneId() {
+  int lane_id;
+  asm("mov.s32 %0, %%laneid;" : "=r"(lane_id));
+  return lane_id;
+}
+
+__device__ __forceinline__ unsigned GetLaneMaskLe() {
+  unsigned mask;
+  asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
+  return mask;
+}
+
+template <typename T, bool KillDependency, class Function>
+__device__ void InclusiveBinaryPrefixScan(T* shared_mem, bool in, T* out,
+                                          Function func) {
+  T vote = __ballot_sync(__activemask(), in);
+  T index = __popc(GetLaneMaskLe() & vote);
+  T carry = __popc(vote);
+
+  int warp = threadIdx.x / 32;
+
+  if (GetLaneId() == 0) {
+    shared_mem[warp] = carry;
+  }
+
+  __syncthreads();
+
+  if (threadIdx.x == 0) {
+    int current = 0;
+    for (int i = 0; i < blockDim.x / 32; ++i) {
+      T v = shared_mem[i];
+      shared_mem[i] = func(shared_mem[i], current);
+      current = func(current, v);
+    }
+  }
+
+  __syncthreads();
+
+  if (warp >= 1) {
+    index = func(index, shared_mem[warp - 1]);
+  }
+
+  *out = index;
+
+  if (KillDependency) {
+    __syncthreads();
+  }
+}
+
+template <typename T, bool KillDependency, class Function>
+__device__ void ExclusiveBinaryPrefixScan(T* shared_mem, bool in, T* out,
+                                          T* carry, Function func) {
+  InclusiveBinaryPrefixScan<T, false, Function>(shared_mem, in, out, func);
+
+  *out -= (T)in;
+
+  *carry = shared_mem[(blockDim.x + 31) / 32 - 1];
+
+  if (KillDependency) {
+    __syncthreads();
+  }
+}
+
+template <typename T, typename RadixType>
+__device__ T FindPattern(const T* input, T* shared_mem, int slice_size,
+                         RadixType desired, RadixType desired_mask) {
+  if (threadIdx.x < 2) {
+    shared_mem[threadIdx.x] = static_cast<T>(0);
+  }
+  __syncthreads();
+
+  int block_dim = static_cast<int>(blockDim.x);
+  int loop = ((slice_size + block_dim - 1) / block_dim * block_dim);
+  for (int i = threadIdx.x; i < loop; i += blockDim.x) {
+    bool valid = (i < slice_size);
+    T v = valid ? input[i] : static_cast<T>(0);
+
+    if (valid &&
+        ((RadixTypeConfig<T>::Convert(v) & desired_mask) == desired)) {
+      shared_mem[0] = static_cast<T>(1);
+      shared_mem[1] = v;
+    }
+
+    __syncthreads();
+
+    T found = shared_mem[0];
+    T val = shared_mem[1];
+
+    __syncthreads();
+
+    if (found != static_cast<T>(0)) {
+      return val;
+    }
+  }
+
+  assert(false);
+  return static_cast<T>(0);
+}
+
+template <typename T, typename RadixType, int RadixSize, int RadixBits>
+__device__ void RadixCountUsingMask(const T* input, int counts[RadixSize],
+                                    int* shared_mem, RadixType desired,
+                                    RadixType desired_mask, int radix_digit_pos,
+                                    int slice_size) {
+#pragma unroll
+  for (int i = 0; i < RadixSize; ++i) {
+    counts[i] = 0;
+  }
+
+  if (threadIdx.x < RadixSize) {
+    shared_mem[threadIdx.x] = 0;
+  }
+  __syncthreads();
+
+  for (int i = threadIdx.x; i < slice_size; i += blockDim.x) {
+    RadixType val = RadixTypeConfig<T>::Convert(input[i]);
+
+    bool has_val = ((val & desired_mask) == desired);
+    RadixType digit_in_radix =
+        Bitfield<RadixType>::GetBitfield(val, radix_digit_pos, RadixBits);
+
+#pragma unroll
+    for (uint32_t j = 0; j < RadixSize; ++j) {
+      bool vote = has_val && (digit_in_radix == j);
+      counts[j] += __popc(__ballot_sync(__activemask(), vote));
+    }
+  }
+
+  if (GetLaneId() == 0) {
+#pragma unroll
+    for (uint32_t i = 0; i < RadixSize; ++i) {
+      platform::CudaAtomicAdd(&shared_mem[i], counts[i]);
+    }
+  }
+
+  __syncthreads();
+
+#pragma unroll
+  for (uint32_t i = 0; i < RadixSize; ++i) {
+    counts[i] = shared_mem[i];
+  }
+
+  __syncthreads();
+}
+
+template <typename T, typename RadixType, bool Largest>
+__device__ void RadixSearch(const T* input, int k, int slice_size,
+                            int* shared_mem, T* kth_value) {
+  int counts[RADIX_SIZE];
+
+  RadixType desired = 0;
+  RadixType desired_mask = 0;
+
+  int k_left = k;
+
+#pragma unroll
+  for (int digit_pos = sizeof(T) * 8 - RADIX_BITS; digit_pos >= 0;
+       digit_pos -= RADIX_BITS) {
+    RadixCountUsingMask<T, RadixType, RADIX_SIZE, RADIX_BITS>(
+        input, counts, shared_mem, desired, desired_mask, digit_pos,
+        slice_size);
+
+    auto found_unique = [&](int i, int count) -> bool {
+      if (count == 1 && k_left == 1) {
+        desired = Bitfield<RadixType>::SetBitfield(desired, i, digit_pos,
+                                                   RADIX_BITS);
+        desired_mask = Bitfield<RadixType>::SetBitfield(
+            desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
+
+        *kth_value =
+            FindPattern<T, RadixType>(input, reinterpret_cast<T*>(shared_mem),
+                                      slice_size, desired, desired_mask);
+        return true;
+      }
+      return false;
+    };
+    auto found_non_unique = [&](int i, int count) -> bool {
+      if (count >= k_left) {
+        desired = Bitfield<RadixType>::SetBitfield(desired, i, digit_pos,
+                                                   RADIX_BITS);
+        desired_mask = Bitfield<RadixType>::SetBitfield(
+            desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
+
+        return true;
+      }
+      k_left -= count;
+      return false;
+    };
+
+    if (Largest) {
+      // Descending order
+#pragma unroll
+      for (int i = RADIX_SIZE - 1; i >= 0; --i) {
+        int count = counts[i];
+        if (found_unique(i, count)) {
+          return;
+        }
+        if (found_non_unique(i, count)) {
+          break;
+        }
+      }
+    } else {
+      // Ascending order
+#pragma unroll
+      for (int i = 0; i < RADIX_SIZE; ++i) {
+        int count = counts[i];
+        if (found_unique(i, count)) {
+          return;
+        }
+        if (found_non_unique(i, count)) {
+          break;
+        }
+      }
+    }
+  }
+
+  *kth_value = RadixTypeConfig<T>::Deconvert(desired);
+}
+
+template <typename T, bool Largest>
+__global__ void RadixTopK(const T* input, int k, int slice_num, int slice_size,
+                          T* output, int64_t* indices) {
+  namespace kps = paddle::operators::kernel_primitives;
+  __shared__ int shared_mem[32];
+
+  // 1. Find the k-th value
+  T kth_value = static_cast<T>(0);
+  RadixSearch<T, typename RadixTypeConfig<T>::RadixType, Largest>(
+      input, k, slice_size, shared_mem, &kth_value);
+  const auto converted_kth_value = RadixTypeConfig<T>::Convert(kth_value);
+
+  // 2. Select the values strictly less/greater than the k-th value and
+  //    record their indices
+  int block_dim = static_cast<int>(blockDim.x);
+  int loop = ((slice_size + block_dim - 1) / block_dim * block_dim);
+  int write_start = 0;
+
+  for (int i = threadIdx.x; i < loop; i += blockDim.x) {
+    bool valid = i < slice_size;
+    T v = valid ? input[i] : static_cast<T>(0);
+    const auto converted_v = RadixTypeConfig<T>::Convert(v);
+    bool is_top_k;
+    if (Largest) {
+      is_top_k = valid && (converted_v > converted_kth_value);
+    } else {
+      is_top_k = valid && (converted_v < converted_kth_value);
+    }
+
+    int index;
+    int carry;
+    ExclusiveBinaryPrefixScan<int, true, kps::AddFunctor<int>>(
+        shared_mem, is_top_k, &index, &carry, kps::AddFunctor<int>());
+    if (is_top_k) {
+      int write_index = write_start + index;
+      output[write_index] = v;
+      indices[write_index] = i;
+    }
+    write_start += carry;
+  }
+
+  // 3. Fill the rest with values == kth_value
+  assert(k >= write_start);
+  int remain = k - write_start;
+  for (int i = threadIdx.x; i < loop; i += blockDim.x) {
+    bool valid = i < slice_size;
+    T v = valid ? input[i] : static_cast<T>(0);
+    const auto converted_v = RadixTypeConfig<T>::Convert(v);
+    bool is_top_k = valid && (converted_v == converted_kth_value);
+
+    int index;
+    int carry;
+    ExclusiveBinaryPrefixScan<int, true, kps::AddFunctor<int>>(
+        shared_mem, is_top_k, &index, &carry, kps::AddFunctor<int>());
+    if (is_top_k && index < remain) {
+      int write_index = write_start + index;
+      assert(write_index < k);
+      output[write_index] = v;
+      indices[write_index] = i;
+    }
+
+    if (carry >= remain) {
+      break;
+    }
+
+    remain -= carry;
+    write_start += carry;
+  }
+}
+#endif
+/*---------------------------Radix TopK End------------------*/
+
 template <typename T>
 __global__ void AssignGrad(T* x_grad, const int64_t* indices,
                            const T* out_grad, size_t rows, size_t cols,
                            size_t k) {
diff --git a/paddle/phi/kernels/gpu/top_k_kernel.cu b/paddle/phi/kernels/gpu/top_k_kernel.cu
index 7f06af7de43..adaf5cc092b 100644
--- a/paddle/phi/kernels/gpu/top_k_kernel.cu
+++ b/paddle/phi/kernels/gpu/top_k_kernel.cu
@@ -18,6 +18,7 @@
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/copy_kernel.h"
+#include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
@@ -91,10 +92,72 @@ void TopkKernel(const Context& dev_ctx,
       // Successed, return.
       return;
     } else {
-      LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
+      VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use "
                  "default topk kernel.";
     }
   }
+
+#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
+  if (input_width >= 1024 && input_height == 1) {
+    // 1. Gather TopK, but without sorting
+    constexpr int max_num_threads = 1024;
+    if (largest) {
+      ops::RadixTopK<
+          T,
+          true><<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
+          input_data,
+          k,
+          input_height,
+          input_width,
+          output_data,
+          indices_data);
+    } else {
+      ops::RadixTopK<
+          T,
+          false><<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
+          input_data,
+          k,
+          input_height,
+          input_width,
+          output_data,
+          indices_data);
+    }
+    // 2. Sort if needed
+    if (sorted) {
+      DenseTensor sorted_output;
+      DenseTensor sorted_indices;
+      DenseTensor gather_indices;
+      sorted_output.Resize(out->dims());
+      sorted_indices.Resize(indices->dims());
+      gather_indices.Resize(indices->dims());
+      dev_ctx.template Alloc<T>(&sorted_output);
+      dev_ctx.template Alloc<int64_t>(&sorted_indices);
+      dev_ctx.template Alloc<int64_t>(&gather_indices);
+      auto* ctx =
+          reinterpret_cast<const paddle::platform::CUDADeviceContext*>(
+              &dev_ctx);
+      if (ops::SortTopk<T>(*ctx,
+                           out,
+                           k,
+                           input_height,
+                           k,
+                           &sorted_output,
+                           &sorted_indices,
+                           largest)) {
+        funcs::GPUGather<int64_t, int64_t>(
+            dev_ctx, *indices, sorted_indices, &gather_indices);
+        Copy(dev_ctx, gather_indices, indices->place(), false, indices);
+        Copy(dev_ctx, sorted_output, out->place(), false, out);
+        return;
+      } else {
+        VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use "
+                   "default topk kernel.";
+      }
+    } else {
+      return;
+    }
+  }
+#endif
 
   // NOTE: pass lds and dim same to input width.
   // NOTE: old matrix implementation of stride is different to eigen.
@@ -199,8 +262,8 @@ void TopkKernel(const Context& dev_ctx,
           ndims, dev_ctx, trans_out, out, trans);
       return;
     } else {
-      LOG(INFO) << "TopKOP: Some errors happened when use cub sorting, use "
-                   "default topk kernel.";
+      VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use "
+                 "default topk kernel.";
     }
   }
--
GitLab
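
For readers who want the selection algorithm in isolation: the sketch below is not part of the patch. It re-derives the digit-narrowing loop of RadixSearch as plain host-side C++ for float with Largest == true, using the same base-4 digits (RADIX_BITS = 2) and the same order-preserving bit trick as RadixTypeConfig<float>. It deliberately drops the device-side machinery (warp ballots, shared-memory histograms, and the count == 1 early exit through FindPattern), and the names (Convert, RadixSelectLargest, kRadixBits) are illustrative, not Paddle API.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Same order-preserving float -> uint32 mapping as RadixTypeConfig<float>:
// ascending float order becomes ascending unsigned order, NaN sorts last.
uint32_t Convert(float v) {
  uint32_t x;
  std::memcpy(&x, &v, sizeof(x));
  uint32_t mask = (x & 0x80000000u) ? 0xffffffffu : 0x80000000u;
  return (v == v) ? (x ^ mask) : 0xffffffffu;
}

float Deconvert(uint32_t v) {
  uint32_t mask = (v & 0x80000000u) ? 0x80000000u : 0xffffffffu;
  uint32_t x = v ^ mask;
  float f;
  std::memcpy(&f, &x, sizeof(f));
  return f;
}

constexpr int kRadixBits = 2;  // base-4 digits, as in the patch
constexpr int kRadixSize = 1 << kRadixBits;

// Returns the k-th largest element (k is 1-based, 1 <= k <= data.size()).
float RadixSelectLargest(const std::vector<float>& data, int k) {
  uint32_t desired = 0;       // digits fixed so far
  uint32_t desired_mask = 0;  // which bit positions are already fixed
  int k_left = k;
  for (int pos = 32 - kRadixBits; pos >= 0; pos -= kRadixBits) {
    // Histogram the current digit over the elements still in the running.
    int counts[kRadixSize] = {0};
    for (float f : data) {
      uint32_t v = Convert(f);
      if ((v & desired_mask) == desired) {
        ++counts[(v >> pos) & (kRadixSize - 1)];
      }
    }
    // Walk buckets from the largest digit down, like the Largest branch of
    // RadixSearch: the bucket holding >= k_left survivors contains the answer.
    for (int d = kRadixSize - 1; d >= 0; --d) {
      if (counts[d] >= k_left) {
        desired |= static_cast<uint32_t>(d) << pos;
        desired_mask |= static_cast<uint32_t>(kRadixSize - 1) << pos;
        break;
      }
      k_left -= counts[d];  // everything in this bucket outranks the k-th value
    }
  }
  // Every digit is now fixed, so `desired` is the k-th value's full encoding.
  return Deconvert(desired);
}

int main() {
  std::vector<float> data = {3.5f, -1.0f, 7.25f, 0.0f, 7.25f, 2.0f};
  for (int k = 1; k <= 3; ++k) {
    std::printf("k=%d -> %g\n", k, RadixSelectLargest(data, k));
  }
  return 0;  // prints 7.25, 7.25, 3.5
}

Once the k-th value is known, the kernel in the patch writes the answer in two prefix-scan passes over the slice: the first pass emits every element strictly better than the k-th value, and the second emits just enough elements equal to it to fill exactly k output slots, which is what keeps the result correct in the presence of duplicates.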