未验证 提交 e18f5339 编写于 作者: T thunder95 提交者: GitHub

【PaddlePaddle Hackathon 4 No.40】为 Paddle 优化 kthvalue op 在 GPU 上的计算性能 (#51835)

* untracked files

* kthvalue perf

* remove unused files

* fix isnan

* fix isnan2

* fix bug

* try to fix rocm error
上级 7415b101
...@@ -32,6 +32,17 @@ limitations under the License. */ ...@@ -32,6 +32,17 @@ limitations under the License. */
#include "paddle/phi/kernels/primitive/functor_primitives.h" #include "paddle/phi/kernels/primitive/functor_primitives.h"
#define FINAL_MASK 0xffffffff #define FINAL_MASK 0xffffffff
#define WARP_SIZE 32
#define MAX_NUM_THREADS 1024
inline static size_t divide_round_up(size_t n, size_t q) {
return n % q == 0 ? n / q : n / q + 1;
}
inline static size_t round_up(size_t n, size_t q) {
return divide_round_up(n, q) * q;
}
#ifdef __HIPCC__ #ifdef __HIPCC__
namespace rocprim { namespace rocprim {
namespace detail { namespace detail {
...@@ -808,6 +819,61 @@ __device__ void RadixSearch( ...@@ -808,6 +819,61 @@ __device__ void RadixSearch(
*kth_value = RadixTypeConfig<T>::Deconvert(desired); *kth_value = RadixTypeConfig<T>::Deconvert(desired);
} }
template <typename T>
__global__ void GatherKthValue(const T* input,
const int k,
const int64_t num_rows,
const int64_t num_cols,
T* output,
int64_t* indices) {
__shared__ int shared_mem[32];
int row =
blockIdx.z * gridDim.y * gridDim.x + blockIdx.y * gridDim.x + blockIdx.x;
const T* cur_input = input + row * num_cols;
// 1. Find the k-th value
T kth_value = static_cast<T>(0);
RadixSearch<T, RadixTypeConfig<T>::RadixType, false>(
cur_input, k, num_cols, shared_mem, &kth_value);
const auto converted_kth_value = RadixTypeConfig<T>::Convert(kth_value);
// 2. find the k-th index
int64_t kth_index = 0;
bool foundKValue = false;
for (int64_t i = threadIdx.x; i < num_cols; i += blockDim.x) {
bool inRange = (i < num_cols);
T v = inRange ? cur_input[i] : static_cast<T>(0);
bool isKValue =
inRange && ((v == kth_value) || (isnan(static_cast<float>(v)) &&
isnan(static_cast<float>(kth_value))));
if (isKValue) {
kth_index = i;
foundKValue = true;
break;
}
}
if (foundKValue) {
output[row] = kth_value;
indices[row] = kth_index;
}
}
template <typename T>
void LaunchGatherKthValue(const phi::GPUContext& dev_ctx,
const T* input_data,
const int64_t num_cols,
const int64_t num_rows,
const int k,
T* out_data,
int64_t* indices_data) {
int num_threads = std::min(
static_cast<int>(round_up(static_cast<int>(num_cols), WARP_SIZE)),
MAX_NUM_THREADS);
GatherKthValue<T><<<num_rows, num_threads, 0, dev_ctx.stream()>>>(
input_data, k, num_rows, num_cols, out_data, indices_data);
}
template <typename T, bool Largest> template <typename T, bool Largest>
__global__ void RadixTopK(const T* input, __global__ void RadixTopK(const T* input,
int k, int k,
......
...@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad, ...@@ -75,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad,
float, float,
double, double,
int, int,
int64_t) {} int64_t,
phi::dtype::float16) {}
...@@ -163,7 +163,6 @@ void KthvalueKernel(const Context& dev_ctx, ...@@ -163,7 +163,6 @@ void KthvalueKernel(const Context& dev_ctx,
const auto& in_dims = x.dims(); const auto& in_dims = x.dims();
if (axis < 0) axis += in_dims.size(); if (axis < 0) axis += in_dims.size();
auto out_dims = output->dims(); auto out_dims = output->dims();
const T* input_data = x.data<T>();
T* output_data = dev_ctx.template Alloc<T>(output); T* output_data = dev_ctx.template Alloc<T>(output);
int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices); int64_t* indices_data = dev_ctx.template Alloc<int64_t>(indices);
...@@ -180,15 +179,28 @@ void KthvalueKernel(const Context& dev_ctx, ...@@ -180,15 +179,28 @@ void KthvalueKernel(const Context& dev_ctx,
phi::funcs::set_constant(dev_ctx, indices, 0); phi::funcs::set_constant(dev_ctx, indices, 0);
return; return;
} }
if (axis == in_dims.size() - 1) { if (axis == in_dims.size() - 1) {
const int64_t& input_height = const int64_t& input_height =
phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1)); phi::product(phi::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1]; const int64_t& input_width = in_dims[in_dims.size() - 1];
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
const T* input_data = x.data<T>();
funcs::LaunchGatherKthValue<T>(dev_ctx,
input_data,
input_width,
input_height,
k,
output_data,
indices_data);
#else
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
SortKthvalue<T>( SortKthvalue<T>(
dev_ctx, &x, input_width, input_height, k, output, indices), dev_ctx, &x, input_width, input_height, k, output, indices),
true, true,
phi::errors::External("KthvalueOP: Error when use cub sorting")); phi::errors::External("KthvalueOP: Error when use cub sorting"));
#endif
return; return;
} else { } else {
std::vector<int> trans; std::vector<int> trans;
...@@ -222,18 +234,28 @@ void KthvalueKernel(const Context& dev_ctx, ...@@ -222,18 +234,28 @@ void KthvalueKernel(const Context& dev_ctx,
trans_out_dims[in_dims.size() - 1] = 1; trans_out_dims[in_dims.size() - 1] = 1;
DenseTensor trans_input; DenseTensor trans_input;
trans_input.Resize(trans_dims); trans_input.Resize(trans_dims);
dev_ctx.template Alloc<T>(&trans_input); T* tran_input_data = dev_ctx.template Alloc<T>(&trans_input);
int ndims = trans.size(); int ndims = trans.size();
funcs::TransCompute<phi::GPUContext, T>( funcs::TransCompute<phi::GPUContext, T>(
ndims, dev_ctx, x, &trans_input, trans); ndims, dev_ctx, x, &trans_input, trans);
DenseTensor trans_ind, trans_out; DenseTensor trans_ind, trans_out;
trans_ind.Resize(trans_out_dims); trans_ind.Resize(trans_out_dims);
trans_out.Resize(trans_out_dims); trans_out.Resize(trans_out_dims);
dev_ctx.template Alloc<int64_t>(&trans_ind); int64_t* tran_indices_data = dev_ctx.template Alloc<int64_t>(&trans_ind);
dev_ctx.template Alloc<T>(&trans_out); T* tran_output_data = dev_ctx.template Alloc<T>(&trans_out);
const int64_t input_height = const int64_t input_height =
phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1)); phi::product(phi::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1]; const int64_t input_width = trans_dims[trans_dims.size() - 1];
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 9000
funcs::LaunchGatherKthValue<T>(dev_ctx,
tran_input_data,
input_width,
input_height,
k,
tran_output_data,
tran_indices_data);
#else
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
SortKthvalue<T>(dev_ctx, SortKthvalue<T>(dev_ctx,
&trans_input, &trans_input,
...@@ -244,6 +266,7 @@ void KthvalueKernel(const Context& dev_ctx, ...@@ -244,6 +266,7 @@ void KthvalueKernel(const Context& dev_ctx,
&trans_ind), &trans_ind),
true, true,
phi::errors::External("KthvalueOP: Error when use cub sorting")); phi::errors::External("KthvalueOP: Error when use cub sorting"));
#endif
funcs::TransCompute<phi::GPUContext, int64_t>( funcs::TransCompute<phi::GPUContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans); ndims, dev_ctx, trans_ind, indices, trans);
funcs::TransCompute<phi::GPUContext, T>( funcs::TransCompute<phi::GPUContext, T>(
...@@ -263,6 +286,7 @@ PD_REGISTER_KERNEL(kthvalue, ...@@ -263,6 +286,7 @@ PD_REGISTER_KERNEL(kthvalue,
float, float,
double, double,
int, int,
int64_t) { int64_t,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64); kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册