commit b4b926f4 (unverified)
Author:   Asthestarsfalll, authored Nov 28, 2022
Committer: GitHub, Nov 28, 2022
Parent:   923ad5dc

migrate top_k_function_cuda.h from fluid to phi (#48251)
Showing 6 changed files with 77 additions and 90 deletions (+77 -90):

 paddle/fluid/operators/top_k_op.cu              |  +8  -8
 paddle/phi/kernels/funcs/top_k_function_cuda.h  | +25 -31
 paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu  |  +3  -3
 paddle/phi/kernels/gpu/kthvalue_kernel.cu       |  +5  -7
 paddle/phi/kernels/gpu/top_k_grad_kernel.cu     |  +3  -5
 paddle/phi/kernels/gpu/top_k_kernel.cu          | +33 -36
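The change is mechanical: the helper header moves from paddle/fluid/operators/top_k_function_cuda.h to paddle/phi/kernels/funcs/top_k_function_cuda.h, and its contents move from namespace paddle::operators to phi::funcs. The sketch below is a hypothetical in-tree caller, not part of this commit, and it only compiles inside the Paddle source tree; it shows the update pattern every call site in the diffs follows, namely swapping the include path and the namespace qualifier while leaving the arguments unchanged.

// Hypothetical wrapper illustrating the migration pattern of #48251.
// SortTopkCompat is an illustrative name, not something added by the commit.
#include <utility>

// Before: #include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"

template <typename T, typename... Args>
bool SortTopkCompat(Args&&... args) {
  // Before: return paddle::operators::SortTopk<T>(std::forward<Args>(args)...);
  return phi::funcs::SortTopk<T>(std::forward<Args>(args)...);
}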
paddle/fluid/operators/top_k_op.cu

@@ -22,9 +22,9 @@ limitations under the License. */
 #include <hipcub/hipcub.hpp>
 #endif
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
 // set cub base traits in order to handle float16
 namespace paddle {
@@ -93,7 +93,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     const int64_t input_width = inputdims[inputdims.size() - 1];
     const auto& dev_ctx = ctx.cuda_device_context();
     if ((input_width <= 1024 || k >= 128 || k == input_width)) {
-      if (SortTopk<T>(
+      if (phi::funcs::SortTopk<T>(
              dev_ctx, input, input_width, input_height, k, output, indices)) {
         // Successed, return.
         return;
@@ -110,12 +110,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
       // TODO(typhoonzero): refine this kernel.
       const int kMaxHeight = 2048;
       int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
-      paddle::platform::GpuLaunchConfig config =
-          paddle::platform::GetGpuLaunchConfig1D(dev_ctx, input_width);
+      phi::backends::gpu::GpuLaunchConfig config =
+          phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, input_width);
       switch (config.thread_per_block.x) {
-        FIXED_BLOCK_DIM(switch (getMaxLength(k)) {
-          FIXED_MAXLENGTH(KeMatrixTopK<T, maxLength, kBlockDim>
+        FIXED_BLOCK_DIM(switch (phi::funcs::getMaxLength(k)) {
+          FIXED_MAXLENGTH(phi::funcs::KeMatrixTopK<T, maxLength, kBlockDim>
                           <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
                               output_data,
                               k,
                               indices_data,
@@ -164,9 +164,9 @@ class TopkOpGradCUDAKernel : public framework::OpKernel<T> {
     const auto& dev_ctx = context.cuda_device_context();
     const int kMaxHeight = 2048;
     int gridx = row < kMaxHeight ? row : kMaxHeight;
-    switch (GetDesiredBlockDim(col)) {
-      FIXED_BLOCK_DIM(AssignGrad<T, 5, kBlockDim>
+    switch (phi::funcs::GetDesiredBlockDim(col)) {
+      FIXED_BLOCK_DIM(phi::funcs::AssignGrad<T, 5, kBlockDim>
                       <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
                           x_grad_data, indices_data, out_grad_data, row, col, k));
       default:
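The only non-mechanical substitution in this file is the launch-config type: paddle::platform::GpuLaunchConfig and GetGpuLaunchConfig1D become phi::backends::gpu::GpuLaunchConfig and GetGpuLaunchConfig1D. Below is a minimal sketch of how the block dimension is obtained under the new spelling, assuming the Paddle source tree; PickBlockDim is an illustrative helper, not part of the diff.

#include <cstdint>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"

// Returns the block dimension the kernels above then dispatch on via the
// FIXED_BLOCK_DIM(...) switch.
inline int PickBlockDim(const phi::GPUContext& dev_ctx, int64_t input_width) {
  phi::backends::gpu::GpuLaunchConfig config =
      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, input_width);
  return config.thread_per_block.x;
}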
paddle/fluid/operators/top_k_function_cuda.h → paddle/phi/kernels/funcs/top_k_function_cuda.h (renamed)

-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -23,21 +23,21 @@ limitations under the License. */
 #ifdef __HIPCC__
 #include <hipcub/hipcub.hpp>
 #endif
-#include "paddle/fluid/operators/eigen/eigen_function.h"
-#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
-#include "paddle/fluid/operators/top_k_op.h"
-#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/float16.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
+#include "paddle/phi/kernels/primitive/functor_primitives.h"
 
 #define FINAL_MASK 0xffffffff
 
 #ifdef __HIPCC__
 namespace rocprim {
 namespace detail {
 template <>
-struct radix_key_codec_base<paddle::platform::float16>
-    : radix_key_codec_integral<paddle::platform::float16, uint16_t> {};
+struct radix_key_codec_base<phi::dtype::float16>
+    : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
 }  // namespace detail
 }  // namespace rocprim
 namespace cub = hipcub;
@@ -45,17 +45,13 @@ namespace cub = hipcub;
 // set cub base traits in order to handle float16
 namespace cub {
 template <>
-struct NumericTraits<paddle::platform::float16>
-    : BaseTraits<FLOATING_POINT, true, false, uint16_t, paddle::platform::float16> {};
+struct NumericTraits<phi::dtype::float16>
+    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
 }  // namespace cub
 #endif
 
-namespace paddle {
-namespace operators {
+namespace phi {
+namespace funcs {
 
 using Tensor = phi::DenseTensor;
@@ -553,10 +549,10 @@ struct RadixTypeConfig<int64_t> {
 };
 
 template <>
-struct RadixTypeConfig<platform::float16> {
+struct RadixTypeConfig<phi::dtype::float16> {
   typedef uint32_t RadixType;
 
-  static inline __device__ RadixType Convert(platform::float16 v) {
+  static inline __device__ RadixType Convert(phi::dtype::float16 v) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
     half v_h = v.to_half();
     RadixType x = __half_as_ushort(v_h);
@@ -568,13 +564,13 @@ struct RadixTypeConfig<platform::float16> {
 #endif
   }
 
-  static inline __device__ platform::float16 Deconvert(RadixType v) {
+  static inline __device__ phi::dtype::float16 Deconvert(RadixType v) {
 #if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
     RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;
-    return static_cast<platform::float16>(__ushort_as_half(v ^ mask));
+    return static_cast<phi::dtype::float16>(__ushort_as_half(v ^ mask));
 #else
     assert(false);
-    return static_cast<platform::float16>(0);
+    return static_cast<phi::dtype::float16>(0);
 #endif
   }
 };
@@ -819,7 +815,6 @@ __global__ void RadixTopK(const T* input,
                           int slice_size,
                           T* output,
                           int64_t* indices) {
-  namespace kps = paddle::operators::kernel_primitives;
   __shared__ int shared_mem[32];
 
   // 1. Find the k-th value
@@ -1152,23 +1147,22 @@ bool SortTopk(const phi::GPUContext& ctx,
     // copy sliced data to output.
     const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, 0};
     const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, k};
-    auto e_indices = framework::EigenMatrix<int64_t>::From(*indices_tensor, dim);
-    auto e_tmp_indices = framework::EigenMatrix<int64_t>::From(
+    auto e_indices = phi::EigenMatrix<int64_t>::From(*indices_tensor, dim);
+    auto e_tmp_indices = phi::EigenMatrix<int64_t>::From(
         static_cast<const Tensor>(temp_indices));
     std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(k)};
     auto dim = phi::make_ddim(odims);
-    auto e_values = framework::EigenMatrix<T>::From(*out_tensor, dim);
-    auto e_tmp_values = framework::EigenMatrix<T>::From(
+    auto e_values = phi::EigenMatrix<T>::From(*out_tensor, dim);
+    auto e_tmp_values = phi::EigenMatrix<T>::From(
         static_cast<const Tensor>(temp_values));
-    EigenSlice<std::decay_t<decltype(dev)>, int64_t, 2>::Eval(
+    phi::funcs::EigenSlice<std::decay_t<decltype(dev)>, int64_t, 2>::Eval(
        dev, e_indices, e_tmp_indices, slice_indices, slice_sizes);
-    EigenSlice<std::decay_t<decltype(dev)>, T, 2>::Eval(
+    phi::funcs::EigenSlice<std::decay_t<decltype(dev)>, T, 2>::Eval(
        dev, e_values, e_tmp_values, slice_indices, slice_sizes);
   }
   return true;
 }
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
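Most of this file is the same substitution (paddle::platform::float16 becomes phi::dtype::float16, paddle::operators becomes phi::funcs). The one piece of real logic visible in the hunks is RadixTypeConfig<float16>::Convert/Deconvert, which maps a half-precision bit pattern to an unsigned key whose integer order matches the floating-point order. The standalone, host-only sketch below restates that bit trick on uint16_t so the round trip can be checked; the real code widens the key to uint32_t (RadixType) and uses __half_as_ushort/__ushort_as_half on device, and the Convert mask shown here is only the assumed inverse of the Deconvert body shown in the diff.

#include <cassert>
#include <cstdint>

// Non-negative floats get their sign bit flipped; negative floats get all bits
// flipped. The mapping is monotone in the float order, so a radix sort on the
// keys sorts the floats.
uint16_t ConvertBits(uint16_t x) {
  uint16_t mask = (x & 0x8000) ? 0xffff : 0x8000;
  return static_cast<uint16_t>(x ^ mask);
}

// Inverse transform; mirrors the Deconvert body in the hunk above.
uint16_t DeconvertBits(uint16_t v) {
  uint16_t mask = (v & 0x8000) ? 0x8000 : 0xffff;
  return static_cast<uint16_t>(v ^ mask);
}

int main() {
  // Round-trip a few half-precision bit patterns: +0.0, 1.0, -1.0.
  for (uint16_t x : {uint16_t{0x0000}, uint16_t{0x3c00}, uint16_t{0xbc00}}) {
    assert(DeconvertBits(ConvertBits(x)) == x);
  }
  return 0;
}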
paddle/phi/kernels/gpu/kthvalue_grad_kernel.cu

@@ -14,9 +14,9 @@
 #include "paddle/phi/kernels/kthvalue_grad_kernel.h"
 
-#include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
 
 namespace phi {
 static int getBlockSize(int col) {
@@ -48,12 +48,12 @@ void KthvalueGradKernel(const Context& dev_ctx,
   const T* out_grad_data = d_out.data<T>();
   const int64_t* indices_data = indices.data<int64_t>();
   int pre, n, post;
-  paddle::operators::GetDims(in_dims, axis, &pre, &n, &post);
+  phi::funcs::GetDims(in_dims, axis, &pre, &n, &post);
   int block_size = getBlockSize(post * k);
   int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
   const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
   int grid_size = std::min(max_blocks, pre);
-  paddle::operators::AssignGradWithAxis<T>
+  phi::funcs::AssignGradWithAxis<T>
       <<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
           out_grad_data, indices_data, x_grad_data, pre, post, n, 1);
 }
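Apart from the namespace change, the hunk above shows how the grad kernel sizes its launch: ceil-divide the device's maximum resident thread count by the block size, then clamp the grid by pre. A standalone restatement of that arithmetic with illustrative numbers (the values are not taken from the diff):

#include <algorithm>
#include <cassert>

// Same formula as KthvalueGradKernel above: grid = min(ceil(max_threads / block_size), pre).
int ComputeGridSize(int max_threads, int block_size, int pre) {
  const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
  return std::min(max_blocks, pre);
}

int main() {
  // E.g. 2048 resident threads and 256-thread blocks allow at most 8 blocks.
  assert(ComputeGridSize(2048, 256, 4) == 4);    // limited by pre
  assert(ComputeGridSize(2048, 256, 100) == 8);  // limited by (2048 - 1) / 256 + 1 = 8
  return 0;
}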
paddle/phi/kernels/gpu/kthvalue_kernel.cu

@@ -14,12 +14,12 @@
 #include "paddle/phi/kernels/kthvalue_kernel.h"
 
-#include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
 
 namespace phi {
 inline int getBlockSize(int col) {
@@ -55,15 +55,13 @@ bool SortKthvalue(const phi::GPUContext& dev_ctx,
   unsigned int grid_size = num_rows < maxGridDimX
                                ? static_cast<unsigned int>(num_rows)
                                : maxGridDimX;
-  paddle::operators::InitIndex<int64_t>
-      <<<grid_size, block_size, 0, cu_stream>>>(
-          input_indices.data<int64_t>(), num_rows, num_cols);
+  phi::funcs::InitIndex<int64_t><<<grid_size, block_size, 0, cu_stream>>>(
+      input_indices.data<int64_t>(), num_rows, num_cols);
   cub::CountingInputIterator<int64_t> counting_iter(0);
   cub::TransformInputIterator<int64_t,
-                              paddle::operators::SegmentOffsetIter,
+                              phi::funcs::SegmentOffsetIter,
                               cub::CountingInputIterator<int64_t>>
-      segment_offsets_t(counting_iter,
-                        paddle::operators::SegmentOffsetIter(num_cols));
+      segment_offsets_t(counting_iter, phi::funcs::SegmentOffsetIter(num_cols));
   T* sorted_values_ptr;
   int64_t* sorted_indices_ptr;
   DenseTensor temp_values, temp_indices;
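The SortKthvalue hunk keeps the same CUB iterator composition and only renames the functor's namespace: a cub::CountingInputIterator is wrapped in a cub::TransformInputIterator so that element i reads as the segment offset i * num_cols, the form CUB's segmented radix sort expects. Below is a standalone sketch compiled with nvcc; RowOffset is a stand-in for phi::funcs::SegmentOffsetIter, whose definition is not part of this diff.

// topk_segment_offsets_sketch.cu
#include <cassert>
#include <cub/cub.cuh>

// Maps row index i to the offset of that row's first element.
struct RowOffset {
  int num_cols;
  explicit RowOffset(int cols) : num_cols(cols) {}
  __host__ __device__ int64_t operator()(const int64_t& row) const {
    return row * num_cols;
  }
};

int main() {
  const int num_cols = 1024;
  cub::CountingInputIterator<int64_t> counting_iter(0);
  cub::TransformInputIterator<int64_t, RowOffset,
                              cub::CountingInputIterator<int64_t>>
      segment_offsets(counting_iter, RowOffset(num_cols));
  // Offset of row 3 is 3 * num_cols, with no materialized offset array.
  assert(segment_offsets[3] == 3 * 1024);
  return 0;
}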
paddle/phi/kernels/gpu/top_k_grad_kernel.cu

@@ -14,15 +14,13 @@
 #include "paddle/phi/kernels/top_k_grad_kernel.h"
 
-#include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
 
 namespace phi {
 
-namespace ops = paddle::operators;
-
 template <typename T, typename Context>
 void TopkGradKernel(const Context& dev_ctx,
                     const DenseTensor& x,
@@ -50,7 +48,7 @@ void TopkGradKernel(const Context& dev_ctx,
   const int64_t* indices_data = indices.data<int64_t>();
   int pre, n, post;
-  ops::GetDims(in_dims, axis, &pre, &n, &post);
+  phi::funcs::GetDims(in_dims, axis, &pre, &n, &post);
   // calcluate the block and grid num
   auto ComputeBlockSize = [](int col) {
@@ -71,7 +69,7 @@ void TopkGradKernel(const Context& dev_ctx,
   int grid_size = std::min(max_blocks, pre);
   // lanuch the cuda kernel to assign the grad
-  ops::AssignGradWithAxis<T>
+  phi::funcs::AssignGradWithAxis<T>
       <<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
           out_grad_data, indices_data, x_grad_data, pre, post, n, k);
 }
paddle/phi/kernels/gpu/top_k_kernel.cu

@@ -14,17 +14,14 @@
 #include "paddle/phi/kernels/top_k_kernel.h"
 
-#include "paddle/fluid/operators/top_k_function_cuda.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
 
 namespace phi {
 
-namespace ops = paddle::operators;
-
 #define FIXED_BLOCK_DIM_BASE(dim, ...) \
   case (dim): {                        \
     constexpr auto kBlockDim = (dim);  \
@@ -95,14 +92,14 @@ void TopkKernel(const Context& dev_ctx,
     // statistics
     if (input_width >= 128 && k >= input_width * 0.75) {
       auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
-      if (ops::SortTopk<T>(
-              *ctx, input, input_width, input_height, k, out, indices, largest)) {
+      if (phi::funcs::SortTopk<T>(
+              *ctx, input, input_width, input_height, k, out, indices, largest)) {
         // Successed, return.
         return;
       } else {
@@ -116,7 +113,7 @@ void TopkKernel(const Context& dev_ctx,
       // 1. Gather TopK, but without sorting
       constexpr int max_num_threads = 1024;
       if (largest) {
-        ops::RadixTopK<T, true>
+        phi::funcs::RadixTopK<T, true>
            <<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
                input_data,
                k,
@@ -125,7 +122,7 @@ void TopkKernel(const Context& dev_ctx,
                output_data,
                indices_data);
       } else {
-        ops::RadixTopK<T, false>
+        phi::funcs::RadixTopK<T, false>
            <<<input_height, max_num_threads, 0, dev_ctx.stream()>>>(
                input_data,
                k,
@@ -146,14 +143,14 @@ void TopkKernel(const Context& dev_ctx,
       dev_ctx.template Alloc<int64_t>(&sorted_indices);
       dev_ctx.template Alloc<int64_t>(&gather_indices);
       auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
-      if (ops::SortTopk<T>(*ctx,
-                           out,
-                           k,
-                           input_height,
-                           k,
-                           &sorted_output,
-                           &sorted_indices,
-                           largest)) {
+      if (phi::funcs::SortTopk<T>(*ctx,
+                                  out,
+                                  k,
+                                  input_height,
+                                  k,
+                                  &sorted_output,
+                                  &sorted_indices,
+                                  largest)) {
         funcs::GPUGather<int64_t, int64_t>(
             dev_ctx, *indices, sorted_indices, &gather_indices);
         Copy(dev_ctx, gather_indices, indices->place(), false, indices);
@@ -178,7 +175,7 @@ void TopkKernel(const Context& dev_ctx,
       switch (config.thread_per_block.x) {
 #ifdef PADDLE_WITH_HIP
-        FIXED_BLOCK_DIM(ops::KeMatrixTopK<T, 20, kBlockDim>
+        FIXED_BLOCK_DIM(phi::funcs::KeMatrixTopK<T, 20, kBlockDim>
                         <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
                             output_data,
                             k,
                             indices_data,
@@ -190,9 +187,9 @@ void TopkKernel(const Context& dev_ctx,
                             input_height,
                             largest));
 #else
-        FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
-          FIXED_MAXLENGTH(ops::KeMatrixTopK<T, maxLength, kBlockDim>
+        FIXED_BLOCK_DIM(switch (phi::funcs::getMaxLength(k)) {
+          FIXED_MAXLENGTH(phi::funcs::KeMatrixTopK<T, maxLength, kBlockDim>
                           <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
                               output_data,
                               k,
                              indices_data,
@@ -260,14 +257,14 @@ void TopkKernel(const Context& dev_ctx,
     // statistics
     if (input_width >= 128 && k >= input_width * 0.75) {
       auto* ctx = reinterpret_cast<const phi::GPUContext*>(&dev_ctx);
-      if (ops::SortTopk<T>(*ctx,
-                           &trans_input,
-                           input_width,
-                           input_height,
-                           k,
-                           &trans_out,
-                           &trans_ind,
-                           largest)) {
+      if (phi::funcs::SortTopk<T>(*ctx,
+                                  &trans_input,
+                                  input_width,
+                                  input_height,
+                                  k,
+                                  &trans_out,
+                                  &trans_ind,
+                                  largest)) {
         // last step, tranpose back the indices and output
         funcs::TransCompute<phi::GPUContext, int64_t>(
             ndims, dev_ctx, trans_ind, indices, trans);
@@ -287,7 +284,7 @@ void TopkKernel(const Context& dev_ctx,
       switch (config.thread_per_block.x) {
 #ifdef PADDLE_WITH_HIP
-        FIXED_BLOCK_DIM(ops::KeMatrixTopK<T, 20, kBlockDim>
+        FIXED_BLOCK_DIM(phi::funcs::KeMatrixTopK<T, 20, kBlockDim>
                         <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
                             trans_out.data<T>(),
                             k,
                             trans_ind.data<int64_t>(),
@@ -299,8 +296,8 @@ void TopkKernel(const Context& dev_ctx,
                             input_height,
                             largest));
 #else
-        FIXED_BLOCK_DIM(switch (ops::getMaxLength(k)) {
-          FIXED_MAXLENGTH(ops::KeMatrixTopK<T, maxLength, kBlockDim>
+        FIXED_BLOCK_DIM(switch (phi::funcs::getMaxLength(k)) {
+          FIXED_MAXLENGTH(phi::funcs::KeMatrixTopK<T, maxLength, kBlockDim>
                           <<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
                               trans_out.data<T>(),
                               k,
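Unchanged by this commit but visible in the hunks above is the dispatch heuristic in TopkKernel: when the row is at least 128 elements wide and k covers at least 75% of it, the kernel uses a full segmented sort (now phi::funcs::SortTopk); otherwise it uses the selection kernels (phi::funcs::RadixTopK or KeMatrixTopK). A standalone restatement of that condition, with illustrative values:

#include <cassert>
#include <cstdint>

// Same predicate as the "if (input_width >= 128 && k >= input_width * 0.75)"
// checks in the hunks above.
bool PreferFullSort(int64_t input_width, int k) {
  return input_width >= 128 && k >= input_width * 0.75;
}

int main() {
  assert(PreferFullSort(1000, 800));   // wide row, k close to the row size: sort path
  assert(!PreferFullSort(1000, 10));   // small k: selection path
  assert(!PreferFullSort(64, 60));     // narrow row: selection path
  return 0;
}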