Unverified commit 45078d9f authored by Zhang Zheng, committed by GitHub

Optimize the perf of top_k when k is too large (#40941)

* Optimize the perf of top_k when k is too large

* fix rocm compile

* fix

* only compile in cuda

* fix log info
Parent d951f3af
...
@@ -23,8 +23,10 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
#endif
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#ifdef __HIPCC__ #ifdef __HIPCC__
...@@ -358,6 +360,427 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, ...@@ -358,6 +360,427 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
} }
} }
/*---------------------------Radix TopK Begin------------------*/
#if defined(PADDLE_WITH_CUDA)
constexpr int RADIX_BITS = 2; // digits are base-(2 ^ RADIX_BITS)
constexpr int RADIX_SIZE = 4; // 2 ^ RADIX_BITS
constexpr int RADIX_MASK = (RADIX_SIZE - 1);
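// The k-th value is located by a most-significant-digit-first radix search
// over unsigned keys: each pass inspects RADIX_BITS bits of the key, so a
// 32-bit key needs at most 32 / RADIX_BITS = 16 passes.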
/*---------------------------Helper Structs------------------*/
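// Bitfield wraps the PTX bit-field instructions: GetBitfield (bfe) extracts
// `len` bits of `val` starting at bit `pos`; SetBitfield (bfi) returns `val`
// with those bits replaced by the low `len` bits of `to_insert`.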
template <typename T>
struct Bitfield {};
template <>
struct Bitfield<unsigned int> {
static __device__ __forceinline__ unsigned int GetBitfield(unsigned int val,
int pos, int len) {
unsigned int ret;
asm("bfe.u32 %0, %1, %2, %3;" : "=r"(ret) : "r"(val), "r"(pos), "r"(len));
return ret;
}
static __device__ __forceinline__ unsigned int SetBitfield(
unsigned int val, unsigned int to_insert, int pos, int len) {
unsigned int ret;
asm("bfi.b32 %0, %1, %2, %3, %4;"
: "=r"(ret)
: "r"(to_insert), "r"(val), "r"(pos), "r"(len));
return ret;
}
};
template <>
struct Bitfield<uint64_t> {
static __device__ __forceinline__ uint64_t GetBitfield(uint64_t val, int pos,
int len) {
uint64_t ret;
asm("bfe.u64 %0, %1, %2, %3;" : "=l"(ret) : "l"(val), "r"(pos), "r"(len));
return ret;
}
static __device__ __forceinline__ uint64_t SetBitfield(uint64_t val,
uint64_t to_insert,
int pos, int len) {
uint64_t ret;
asm("bfi.b64 %0, %1, %2, %3, %4;"
: "=l"(ret)
: "l"(to_insert), "l"(val), "r"(pos), "r"(len));
return ret;
}
};
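// RadixTypeConfig maps each element type to an unsigned RadixType such that
// unsigned comparison of Convert(a) and Convert(b) matches the ordering of
// a and b, and Deconvert(Convert(a)) == a for non-NaN inputs. NaNs convert
// to the all-ones key, so they sort above every real number.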
template <typename T>
struct RadixTypeConfig {};
template <>
struct RadixTypeConfig<float> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
return (v == v) ? (x ^ mask) : 0xffffffff;
}
static inline __device__ float Deconvert(RadixType v) {
RadixType mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff;
return __int_as_float(v ^ mask);
}
};
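// Illustrative bit patterns for the float mapping above:
//   Convert(1.0f)  = 0x3f800000 ^ 0x80000000 = 0xbf800000
//   Convert(-1.0f) = 0xbf800000 ^ 0xffffffff = 0x407fffff
//   Convert(NaN)   = 0xffffffff
// As unsigned keys, 0x407fffff < 0xbf800000, preserving -1.0f < 1.0f.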
template <>
struct RadixTypeConfig<double> {
typedef uint64_t RadixType;
static inline __device__ RadixType Convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
}
static inline __device__ double Deconvert(RadixType v) {
RadixType mask = ((v >> 63) - 1) | 0x8000000000000000;
return __longlong_as_double(v ^ mask);
}
};
template <>
struct RadixTypeConfig<int32_t> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(int32_t v) {
static_assert(sizeof(int) == 4, "");
return 2147483648u + v;
}
static inline __device__ int32_t Deconvert(RadixType v) {
return v - 2147483648u;
}
};
template <>
struct RadixTypeConfig<int64_t> {
typedef uint64_t RadixType;
static inline __device__ RadixType Convert(int64_t v) {
static_assert(sizeof(int64_t) == 8, "");
return 9223372036854775808ull + v;
}
static inline __device__ int64_t Deconvert(RadixType v) {
return v - 9223372036854775808ull;
}
};
template <>
struct RadixTypeConfig<platform::float16> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(platform::float16 v) {
half v_h = v.to_half();
RadixType x = __half_as_ushort(v_h);
RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
return (v_h == v_h) ? (x ^ mask) : 0xffff;
}
static inline __device__ platform::float16 Deconvert(RadixType v) {
RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
return static_cast<platform::float16>(__ushort_as_half(v ^ mask));
}
};
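// Note: float16 keys occupy only the low 16 bits of the 32-bit RadixType;
// RadixSearch starts at digit_pos = sizeof(T) * 8 - RADIX_BITS = 14 for
// this type, so the unused high bits are never scanned.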
/*---------------------------Helper Functions------------------*/
__device__ __forceinline__ int GetLaneId() {
int lane_id;
asm("mov.s32 %0, %%laneid;" : "=r"(lane_id));
return lane_id;
}
__device__ __forceinline__ unsigned GetLaneMaskLe() {
unsigned mask;
asm("mov.u32 %0, %%lanemask_le;" : "=r"(mask));
return mask;
}
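// InclusiveBinaryPrefixScan: each thread contributes a 0/1 flag `in` and
// receives in *out the count of set flags among threads 0..threadIdx.x.
// Within a warp this is __ballot_sync + __popc over the lanes at or below
// the caller (GetLaneMaskLe); per-warp totals are then scanned serially by
// thread 0 through shared memory (one int per warp, blockDim.x assumed to
// be a multiple of 32).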
template <typename T, bool KillDependency, class Function>
__device__ void InclusiveBinaryPrefixScan(T* shared_mem, bool in, T* out,
Function func) {
T vote = __ballot_sync(__activemask(), in);
T index = __popc(GetLaneMaskLe() & vote);
T carry = __popc(vote);
int warp = threadIdx.x / 32;
if (GetLaneId() == 0) {
shared_mem[warp] = carry;
}
__syncthreads();
if (threadIdx.x == 0) {
int current = 0;
for (int i = 0; i < blockDim.x / 32; ++i) {
T v = shared_mem[i];
shared_mem[i] = func(shared_mem[i], current);
current = func(current, v);
}
}
__syncthreads();
if (warp >= 1) {
index = func(index, shared_mem[warp - 1]);
}
*out = index;
if (KillDependency) {
__syncthreads();
}
}
template <typename T, bool KillDependency, class Function>
__device__ void ExclusiveBinaryPrefixScan(T* shared_mem, bool in, T* out,
T* carry, Function func) {
InclusiveBinaryPrefixScan<T, false, Function>(shared_mem, in, out, func);
*out -= (T)in;
*carry = shared_mem[(blockDim.x + 31) / 32 - 1];
if (KillDependency) {
__syncthreads();
}
}
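// FindPattern scans the slice for an element whose converted key matches
// `desired` on the bits selected by `desired_mask`. It is only called once
// the radix search has pinned down a unique remaining candidate, so the
// first match is the answer; shared_mem[0] (found flag) and shared_mem[1]
// (value) broadcast it to the whole block.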
template <typename T, typename RadixType>
__device__ T FindPattern(const T* input, T* shared_mem, int slice_size,
RadixType desired, RadixType desired_mask) {
if (threadIdx.x < 2) {
shared_mem[threadIdx.x] = static_cast<T>(0);
}
__syncthreads();
int block_dim = static_cast<int>(blockDim.x);
int loop = ((slice_size + block_dim - 1) / block_dim * block_dim);
for (int i = threadIdx.x; i < loop; i += blockDim.x) {
bool valid = (i < slice_size);
T v = valid ? input[i] : static_cast<T>(0);
if (valid && ((RadixTypeConfig<T>::Convert(v) & desired_mask) == desired)) {
shared_mem[0] = static_cast<T>(1);
shared_mem[1] = v;
}
__syncthreads();
T found = shared_mem[0];
T val = shared_mem[1];
__syncthreads();
if (found != static_cast<T>(0)) {
return val;
}
}
assert(false);
return static_cast<T>(0);
}
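// RadixCountUsingMask builds a histogram of the RADIX_BITS-wide digit at
// bit offset radix_digit_pos, counting only elements whose key already
// matches `desired` under `desired_mask`. Per-warp counts come from
// __ballot_sync + __popc; lane 0 of each warp accumulates them into shared
// memory atomically, and the block-wide totals are read back by every
// thread.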
template <typename T, typename RadixType, int RadixSize, int RadixBits>
__device__ void RadixCountUsingMask(const T* input, int counts[RadixSize],
int* shared_mem, RadixType desired,
RadixType desired_mask, int radix_digit_pos,
int slice_size) {
#pragma unroll
for (int i = 0; i < RadixSize; ++i) {
counts[i] = 0;
}
if (threadIdx.x < RadixSize) {
shared_mem[threadIdx.x] = 0;
}
__syncthreads();
for (int i = threadIdx.x; i < slice_size; i += blockDim.x) {
RadixType val = RadixTypeConfig<T>::Convert(input[i]);
bool has_val = ((val & desired_mask) == desired);
RadixType digit_in_radix =
Bitfield<RadixType>::GetBitfield(val, radix_digit_pos, RadixBits);
#pragma unroll
for (uint32_t j = 0; j < RadixSize; ++j) {
bool vote = has_val && (digit_in_radix == j);
counts[j] += __popc(__ballot_sync(__activemask(), vote));
}
}
if (GetLaneId() == 0) {
#pragma unroll
for (uint32_t i = 0; i < RadixSize; ++i) {
platform::CudaAtomicAdd(&shared_mem[i], counts[i]);
}
}
__syncthreads();
#pragma unroll
for (uint32_t i = 0; i < RadixSize; ++i) {
counts[i] = shared_mem[i];
}
__syncthreads();
}
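// RadixSearch walks the digits from most to least significant. Each pass
// histograms the current digit among the remaining candidates; buckets that
// precede the answer in scan order (larger digits when Largest, smaller
// otherwise) only decrement k_left, while the bucket containing the answer
// is locked into desired/desired_mask before descending one digit. A bucket
// holding exactly one candidate when k_left == 1 short-circuits through
// FindPattern; otherwise the answer is fully determined once every digit is
// fixed, and Deconvert(desired) is returned.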
template <typename T, typename RadixType, bool Largest>
__device__ void RadixSearch(const T* input, int k, int slice_size,
int* shared_mem, T* kth_value) {
int counts[RADIX_SIZE];
RadixType desired = 0;
RadixType desired_mask = 0;
int k_left = k;
#pragma unroll
for (int digit_pos = sizeof(T) * 8 - RADIX_BITS; digit_pos >= 0;
digit_pos -= RADIX_BITS) {
RadixCountUsingMask<T, RadixType, RADIX_SIZE, RADIX_BITS>(
input, counts, shared_mem, desired, desired_mask, digit_pos,
slice_size);
auto found_unique = [&](int i, int count) -> bool {
if (count == 1 && k_left == 1) {
desired =
Bitfield<RadixType>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
desired_mask = Bitfield<RadixType>::SetBitfield(
desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
*kth_value =
FindPattern<T, RadixType>(input, reinterpret_cast<T*>(shared_mem),
slice_size, desired, desired_mask);
return true;
}
return false;
};
auto found_non_unique = [&](int i, int count) -> bool {
if (count >= k_left) {
desired =
Bitfield<RadixType>::SetBitfield(desired, i, digit_pos, RADIX_BITS);
desired_mask = Bitfield<RadixType>::SetBitfield(
desired_mask, RADIX_MASK, digit_pos, RADIX_BITS);
return true;
}
k_left -= count;
return false;
};
if (Largest) {
// Descending order
#pragma unroll
for (int i = RADIX_SIZE - 1; i >= 0; --i) {
int count = counts[i];
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
} else {
// Ascending order
#pragma unroll
for (int i = 0; i < RADIX_SIZE; ++i) {
int count = counts[i];
if (found_unique(i, count)) {
return;
}
if (found_non_unique(i, count)) {
break;
}
}
}
}
*kth_value = RadixTypeConfig<T>::Deconvert(desired);
}
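// RadixTopK: one thread block processes one slice. Phase 1 finds the k-th
// value with RadixSearch; phase 2 compacts all elements strictly greater
// (Largest) or smaller than it, using an exclusive prefix scan so outputs
// keep the input order; phase 3 tops up the remaining slots with elements
// equal to the k-th value. shared_mem[32] provides one int per warp for
// blocks of up to 1024 threads.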
template <typename T, bool Largest>
__global__ void RadixTopK(const T* input, int k, int slice_num, int slice_size,
T* output, int64_t* indices) {
namespace kps = paddle::operators::kernel_primitives;
__shared__ int shared_mem[32];
// 1. Find the k-th value
T kth_value = static_cast<T>(0);
RadixSearch<T, typename RadixTypeConfig<T>::RadixType, Largest>(
input, k, slice_size, shared_mem, &kth_value);
const auto converted_kth_value = RadixTypeConfig<T>::Convert(kth_value);
// 2. Select the value strictly less/greater than kth_value and their indices
int block_dim = static_cast<int>(blockDim.x);
int loop = ((slice_size + block_dim - 1) / block_dim * block_dim);
int write_start = 0;
for (int i = threadIdx.x; i < loop; i += blockDim.x) {
bool valid = i < slice_size;
T v = valid ? input[i] : static_cast<T>(0);
const auto converted_v = RadixTypeConfig<T>::Convert(v);
bool is_top_k;
if (Largest) {
is_top_k = valid && (converted_v > converted_kth_value);
} else {
is_top_k = valid && (converted_v < converted_kth_value);
}
int index;
int carry;
ExclusiveBinaryPrefixScan<int, true, kps::AddFunctor<int>>(
shared_mem, is_top_k, &index, &carry, kps::AddFunctor<int>());
if (is_top_k) {
int write_index = write_start + index;
output[write_index] = v;
indices[write_index] = i;
}
write_start += carry;
}
// 3. Fill the rest with value == kth_value
assert(k >= write_start);
int remain = k - write_start;
for (int i = threadIdx.x; i < loop; i += blockDim.x) {
bool valid = i < slice_size;
T v = valid ? input[i] : static_cast<T>(0);
const auto converted_v = RadixTypeConfig<T>::Convert(v);
bool is_top_k = valid && (converted_v == converted_kth_value);
int index;
int carry;
ExclusiveBinaryPrefixScan<int, true, kps::AddFunctor<int>>(
shared_mem, is_top_k, &index, &carry, kps::AddFunctor<int>());
if (is_top_k && index < remain) {
int write_index = write_start + index;
assert(write_index < k);
output[write_index] = v;
indices[write_index] = i;
}
if (carry >= remain) {
break;
}
remain -= carry;
write_start += carry;
}
}
#endif
/*---------------------------Radix TopK End------------------*/
template <typename T, int MaxLength, int BlockSize>
__global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
size_t rows, size_t cols, size_t k) {
...
...
@@ -18,6 +18,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace phi { namespace phi {
...
@@ -91,10 +92,72 @@ void TopkKernel(const Context& dev_ctx,
// Succeeded, return.
return;
} else {
VLOG(4) << "TopKOP: Some errors happened when using cub sorting, use "
"default topk kernel.";
}
}
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
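// Radix-select path, added for the "k is too large" case this commit
// targets: one block per slice (grid = input_height, here 1), 1024 threads
// per block. Only a single long row takes this path; everything else falls
// through to the default topk kernel below.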
if (input_width >= 1024 && input_height == 1) {
// 1. Gather TopK, but without sorting
constexpr int max_num_threads = 1024;
if (largest) {
ops::RadixTopK<
T,
true><<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
input_data,
k,
input_height,
input_width,
output_data,
indices_data);
} else {
ops::RadixTopK<
T,
false><<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
input_data,
k,
input_height,
input_width,
output_data,
indices_data);
}
// 2. Sort if needed
if (sorted) {
DenseTensor sorted_output;
DenseTensor sorted_indices;
DenseTensor gather_indices;
sorted_output.Resize(out->dims());
sorted_indices.Resize(indices->dims());
gather_indices.Resize(indices->dims());
dev_ctx.template Alloc<T>(&sorted_output);
dev_ctx.template Alloc<int64_t>(&sorted_indices);
dev_ctx.template Alloc<int64_t>(&gather_indices);
auto* ctx =
reinterpret_cast<const paddle::platform::CUDADeviceContext*>(
&dev_ctx);
if (ops::SortTopk<T>(*ctx,
out,
k,
input_height,
k,
&sorted_output,
&sorted_indices,
largest)) {
funcs::GPUGather<int64_t, int64_t>(
dev_ctx, *indices, sorted_indices, &gather_indices);
Copy(dev_ctx, gather_indices, indices->place(), false, indices);
Copy(dev_ctx, sorted_output, out->place(), false, out);
return;
} else {
VLOG(4) << "TopKOP: Some errors happened when use cub sorting, use "
"default topk kernel."; "default topk kernel.";
}
} else {
return;
}
}
#endif
// NOTE: pass lds and dim the same as the input width.
// NOTE: the old matrix implementation's stride differs from eigen's.
...
@@ -199,8 +262,8 @@ void TopkKernel(const Context& dev_ctx,
ndims, dev_ctx, trans_out, out, trans);
return;
} else {
VLOG(4) << "TopKOP: Some errors happened when using cub sorting, use "
"default topk kernel.";
}
}
...