Unverified commit b4b926f4 authored by Asthestarsfalll, committed by GitHub

migrate top_k_function_cuda.h from fluid to phi (#48251)

Parent 923ad5dc
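
The change is mechanical: every call site swaps the fluid header and namespaces for their phi counterparts. A minimal sketch of what that looks like for downstream code, assembled from the hunks below (the wrapper function and argument names are illustrative, not part of the commit):

// Before (fluid):
//   #include "paddle/fluid/operators/top_k_function_cuda.h"
//   paddle::operators::SortTopk<T>(dev_ctx, input, width, height, k, out, indices);
//   paddle::platform::float16 half_val;
//
// After (phi):
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"

template <typename T>
bool SortTopkExample(const phi::GPUContext& dev_ctx,
                     const phi::DenseTensor* input,
                     int64_t width,
                     int64_t height,
                     int k,
                     phi::DenseTensor* out,
                     phi::DenseTensor* indices) {
  phi::dtype::float16 half_val;  // float16 moves from paddle::platform to phi::dtype
  (void)half_val;
  // Same helper, now in phi::funcs; returns false when the sort path fails,
  // mirroring the fallback logic in the kernels changed below.
  return phi::funcs::SortTopk<T>(dev_ctx, input, width, height, k, out, indices);
}
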
......@@ -22,9 +22,9 @@ limitations under the License. */
#include <hipcub/hipcub.hpp>
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
// set cub base traits in order to handle float16
namespace paddle {
......@@ -93,7 +93,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
const int64_t input_width = inputdims[inputdims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
if ((input_width <= 1024 || k >= 128 || k == input_width)) {
if (SortTopk<T>(
if (phi::funcs::SortTopk<T>(
dev_ctx, input, input_width, input_height, k, output, indices)) {
// Succeeded, return.
return;
......@@ -110,12 +110,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// TODO(typhoonzero): refine this kernel.
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
paddle::platform::GpuLaunchConfig config =
paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
phi::backends::gpu::GpuLaunchConfig config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, input_width);
switch (config.thread_per_block.x) {
FIXED_BLOCK_DIM(switch (getMaxLength(k)) {
FIXED_BLOCK_DIM(switch (phi::funcs::getMaxLength(k)) {
FIXED_MAXLENGTH(
KeMatrixTopK<T, maxLength, kBlockDim>
phi::funcs::KeMatrixTopK<T, maxLength, kBlockDim>
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
k,
indices_data,
......@@ -164,9 +164,9 @@ class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
const auto& dev_ctx = context.cuda_device_context();
const int kMaxHeight = 2048;
int gridx = row < kMaxHeight ? row : kMaxHeight;
switch (GetDesiredBlockDim(col)) {
switch (phi::funcs::GetDesiredBlockDim(col)) {
FIXED_BLOCK_DIM(
AssignGrad<T, 5, kBlockDim>
phi::funcs::AssignGrad<T, 5, kBlockDim>
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
x_grad_data, indices_data, out_grad_data, row, col, k));
default:
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -23,21 +23,21 @@ limitations under the License. */
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#define FINAL_MASK 0xffffffff
#ifdef __HIPCC__
namespace rocprim {
namespace detail {
template <>
struct radix_key_codec_base<paddle::platform::float16>
: radix_key_codec_integral<paddle::platform::float16, uint16_t> {};
struct radix_key_codec_base<phi::dtype::float16>
: radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
} // namespace detail
} // namespace rocprim
namespace cub = hipcub;
......@@ -45,17 +45,13 @@ namespace cub = hipcub;
// set cub base traits in order to handle float16
namespace cub {
template <>
struct NumericTraits<paddle::platform::float16>
: BaseTraits<FLOATING_POINT,
true,
false,
uint16_t,
paddle::platform::float16> {};
struct NumericTraits<phi::dtype::float16>
: BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
} // namespace cub
#endif
namespace paddle {
namespace operators {
namespace phi {
namespace funcs {
using Tensor = phi::DenseTensor;
......@@ -553,10 +549,10 @@ struct RadixTypeConfig<int64_t> {
};
template <>
struct RadixTypeConfig<platform::float16> {
struct RadixTypeConfig<phi::dtype::float16> {
typedef uint32_t RadixType;
static inline __device__ RadixType Convert(platform::float16 v) {
static inline __device__ RadixType Convert(phi::dtype::float16 v) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
half v_h = v.to_half();
RadixType x = __half_as_ushort(v_h);
......@@ -568,13 +564,13 @@ struct RadixTypeConfig<platform::float16> {
#endif
}
static inline __device__ platform::float16 Deconvert(RadixType v) {
static inline __device__ phi::dtype::float16 Deconvert(RadixType v) {
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
return static_cast<platform::float16>(__ushort_as_half(v ^ mask));
return static_cast<phi::dtype::float16>(__ushort_as_half(v ^ mask));
#else
assert(false);
return static_cast<platform::float16>(0);
return static_cast<phi::dtype::float16>(0);
#endif
}
};
......@@ -819,7 +815,6 @@ __global__ void RadixTopK(const T* input,
int slice_size,
T* output,
int64_t* indices) {
namespace kps = paddle::operators::kernel_primitives;
__shared__ int shared_mem[32];
// 1. Find the k-th value
......@@ -1152,23 +1147,22 @@ bool SortTopk(const phi::GPUContext& ctx,
// copy sliced data to output.
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
auto e_indices =
framework::EigenMatrix<int64_t>::From(*indices_tensor, dim);
auto e_tmp_indices = framework::EigenMatrix<int64_t>::From(
auto e_indices = phi::EigenMatrix<int64_t>::From(*indices_tensor, dim);
auto e_tmp_indices = phi::EigenMatrix<int64_t>::From(
static_cast<const Tensor>(temp_indices));
std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
auto dim = phi::make_ddim(odims);
auto e_values = framework::EigenMatrix<T>::From(*out_tensor, dim);
auto e_values = phi::EigenMatrix<T>::From(*out_tensor, dim);
auto e_tmp_values =
framework::EigenMatrix<T>::From(static_cast<const Tensor>(temp_values));
phi::EigenMatrix<T>::From(static_cast<const Tensor>(temp_values));
EigenSlice<std::decay_t<decltype(dev)>, int64_t, 2>::Eval(
phi::funcs::EigenSlice<std::decay_t<decltype(dev)>, int64_t, 2>::Eval(
dev, e_indices, e_tmp_indices, slice_indices, slice_sizes);
EigenSlice<std::decay_t<decltype(dev)>, T, 2>::Eval(
phi::funcs::EigenSlice<std::decay_t<decltype(dev)>, T, 2>::Eval(
dev, e_values, e_tmp_values, slice_indices, slice_sizes);
}
return true;
}
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
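
The rocprim and cub trait specializations above, together with RadixTypeConfig<phi::dtype::float16>, all hinge on one order-preserving bit transform so that float16 keys can be radix-sorted as unsigned integers. A standalone host-side sketch of that transform, with my own helper names (the device code above does the same thing via __half_as_ushort / __ushort_as_half):

#include <cstdint>

// Map float16 bits to an unsigned key whose integer order matches the float
// order: flip only the sign bit for non-negative values, flip all bits for
// negative values. DeconvertFp16Bits inverts the mapping, mirroring the
// Convert/Deconvert pair in RadixTypeConfig<phi::dtype::float16>.
inline uint16_t ConvertFp16Bits(uint16_t bits) {
  uint16_t mask = (bits & 0x8000) ? 0xffff : 0x8000;
  return bits ^ mask;
}

inline uint16_t DeconvertFp16Bits(uint16_t key) {
  uint16_t mask = (key & 0x8000) ? 0x8000 : 0xffff;
  return key ^ mask;
}
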
......@@ -14,9 +14,9 @@
#include "paddle/phi/kernels/kthvalue_grad_kernel.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
namespace phi {
static int getBlockSize(int col) {
......@@ -48,12 +48,12 @@ void KthvalueGradKernel(const Context& dev_ctx,
const T* out_grad_data = d_out.data<T>();
const int64_t* indices_data = indices.data<int64_t>();
int pre, n, post;
paddle::operators::GetDims(in_dims, axis, &pre, &n, &post);
phi::funcs::GetDims(in_dims, axis, &pre, &n, &post);
int block_size = getBlockSize(post * k);
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
int grid_size = std::min(max_blocks, pre);
paddle::operators::AssignGradWithAxis<T>
phi::funcs::AssignGradWithAxis<T>
<<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
out_grad_data, indices_data, x_grad_data, pre, post, n, 1);
}
......
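
The kthvalue grad kernel above (and the top_k grad kernel later in this commit) first collapses the input shape into a (pre, n, post) triple around the reduced axis via phi::funcs::GetDims. A host-side sketch of the presumed semantics of that helper, for illustration only (the real definition lives in the migrated header):

#include <cstdint>
#include <vector>

// pre  = product of dims before `axis`
// n    = dims[axis]
// post = product of dims after `axis`
// so a flat index i decomposes as (i / (n * post), (i / post) % n, i % post).
static void GetDimsSketch(const std::vector<int64_t>& dims, int axis,
                          int* pre, int* n, int* post) {
  *pre = 1;
  *n = static_cast<int>(dims[axis]);
  *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= static_cast<int>(dims[i]);
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i)
    *post *= static_cast<int>(dims[i]);
}
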
......@@ -14,12 +14,12 @@
#include "paddle/phi/kernels/kthvalue_kernel.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
namespace phi {
inline int getBlockSize(int col) {
......@@ -55,15 +55,13 @@ bool SortKthvalue(const phi::GPUContext& dev_ctx,
unsigned int grid_size = num_rows < maxGridDimX
? static_cast<unsigned int>(num_rows)
: maxGridDimX;
paddle::operators::InitIndex<int64_t>
<<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
phi::funcs::InitIndex<int64_t><<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
cub::CountingInputIterator<int64_t> counting_iter(0);
cub::TransformInputIterator<int64_t,
paddle::operators::SegmentOffsetIter,
phi::funcs::SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter,
paddle::operators::SegmentOffsetIter(num_cols));
segment_offsets_t(counting_iter, phi::funcs::SegmentOffsetIter(num_cols));
T* sorted_values_ptr;
int64_t* sorted_indices_ptr;
DenseTensor temp_values, temp_indices;
......
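
SortKthvalue above hands cub's segmented sort its per-row offsets without materializing them: a counting iterator wrapped in a transform iterator whose functor maps row index i to i * num_cols. A hedged sketch of that composition, with RowOffsetOp standing in for phi::funcs::SegmentOffsetIter:

#include <cub/cub.cuh>

// Stand-in for SegmentOffsetIter: segment i starts at i * num_cols.
struct RowOffsetOp {
  explicit RowOffsetOp(int num_cols) : num_cols_(num_cols) {}
  __host__ __device__ int64_t operator()(const int64_t& idx) const {
    return idx * num_cols_;
  }
  int num_cols_;
};

// Begin-offsets for the rows of a (num_rows x num_cols) matrix, generated on
// the fly; cub's segmented radix sort accepts such iterators for
// d_begin_offsets, and `offsets + 1` serves as d_end_offsets.
inline auto MakeSegmentOffsets(int num_cols) {
  cub::CountingInputIterator<int64_t> counting_iter(0);
  cub::TransformInputIterator<int64_t, RowOffsetOp,
                              cub::CountingInputIterator<int64_t>>
      offsets(counting_iter, RowOffsetOp(num_cols));
  return offsets;
}
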
......@@ -14,15 +14,13 @@
#include "paddle/phi/kernels/top_k_grad_kernel.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
namespace phi {
namespace ops = paddle::operators;
template <typename T, typename Context>
void TopkGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......@@ -50,7 +48,7 @@ void TopkGradKernel(const Context& dev_ctx,
const int64_t* indices_data = indices.data<int64_t>();
int pre, n, post;
ops::GetDims(in_dims, axis, &pre, &n, &post);
phi::funcs::GetDims(in_dims, axis, &pre, &n, &post);
// calculate the block and grid num
auto ComputeBlockSize = [](int col) {
......@@ -71,7 +69,7 @@ void TopkGradKernel(const Context& dev_ctx,
int grid_size = std::min(max_blocks, pre);
// launch the cuda kernel to assign the grad
ops::AssignGradWithAxis<T>
phi::funcs::AssignGradWithAxis<T>
<<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
out_grad_data, indices_data, x_grad_data, pre, post, n, k);
}
......
......@@ -14,17 +14,14 @@
#include "paddle/phi/kernels/top_k_kernel.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
namespace phi {
namespace ops = paddle::operators;
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
......@@ -95,14 +92,14 @@ void TopkKernel(const Context& dev_ctx,
// statistics
if (input_width >= 128 && k >= input_width * 0.75) {
auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
if (ops::SortTopk<T>(*ctx,
input,
input_width,
input_height,
k,
out,
indices,
largest)) {
if (phi::funcs::SortTopk<T>(*ctx,
input,
input_width,
input_height,
k,
out,
indices,
largest)) {
// Succeeded, return.
return;
} else {
......@@ -116,7 +113,7 @@ void TopkKernel(const Context& dev_ctx,
// 1. Gather TopK, but without sorting
constexpr int max_num_threads = 1024;
if (largest) {
ops::RadixTopK<T, true>
phi::funcs::RadixTopK<T, true>
<<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
input_data,
k,
......@@ -125,7 +122,7 @@ void TopkKernel(const Context& dev_ctx,
output_data,
indices_data);
} else {
ops::RadixTopK<T, false>
phi::funcs::RadixTopK<T, false>
<<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
input_data,
k,
......@@ -146,14 +143,14 @@ void TopkKernel(const Context& dev_ctx,
dev_ctx.template Alloc<int64_t>(&sorted_indices);
dev_ctx.template Alloc<int64_t>(&gather_indices);
auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
if (ops::SortTopk<T>(*ctx,
out,
k,
input_height,
k,
&sorted_output,
&sorted_indices,
largest)) {
if (phi::funcs::SortTopk<T>(*ctx,
out,
k,
input_height,
k,
&sorted_output,
&sorted_indices,
largest)) {
funcs::GPUGather<int64_t, int64_t>(
dev_ctx, *indices, sorted_indices, &gather_indices);
Copy(dev_ctx, gather_indices, indices->place(), false, indices);
......@@ -178,7 +175,7 @@ void TopkKernel(const Context& dev_ctx,
switch (config.thread_per_block.x) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(
ops::KeMatrixTopK<T, 20, kBlockDim>
phi::funcs::KeMatrixTopK<T, 20, kBlockDim>
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
k,
indices_data,
......@@ -190,9 +187,9 @@ void TopkKernel(const Context& dev_ctx,
input_height,
largest));
#else
FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
FIXED_BLOCK_DIM(switch (phi::funcs::getMaxLength(k)) {
FIXED_MAXLENGTH(
ops::KeMatrixTopK<T, maxLength, kBlockDim>
phi::funcs::KeMatrixTopK<T, maxLength, kBlockDim>
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(output_data,
k,
indices_data,
......@@ -260,14 +257,14 @@ void TopkKernel(const Context& dev_ctx,
// statistics
if (input_width >= 128 && k >= input_width * 0.75) {
auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
if (ops::SortTopk<T>(*ctx,
&trans_input,
input_width,
input_height,
k,
&trans_out,
&trans_ind,
largest)) {
if (phi::funcs::SortTopk<T>(*ctx,
&trans_input,
input_width,
input_height,
k,
&trans_out,
&trans_ind,
largest)) {
// last step, transpose back the indices and output
funcs::TransCompute<phi::GPUContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
......@@ -287,7 +284,7 @@ void TopkKernel(const Context& dev_ctx,
switch (config.thread_per_block.x) {
#ifdef PADDLE_WITH_HIP
FIXED_BLOCK_DIM(
ops::KeMatrixTopK<T, 20, kBlockDim>
phi::funcs::KeMatrixTopK<T, 20, kBlockDim>
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(trans_out.data<T>(),
k,
trans_ind.data<int64_t>(),
......@@ -299,8 +296,8 @@ void TopkKernel(const Context& dev_ctx,
input_height,
largest));
#else
FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
FIXED_MAXLENGTH(ops::KeMatrixTopK<T, maxLength, kBlockDim>
FIXED_BLOCK_DIM(switch (phi::funcs::getMaxLength(k)) {
FIXED_MAXLENGTH(phi::funcs::KeMatrixTopK<T, maxLength, kBlockDim>
<<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
trans_out.data<T>(),
k,
......