Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
41e4f7ea
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41e4f7ea
编写于
10月 08, 2018
作者:
Q
qingqing01
提交者:
GitHub
10月 08, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize Topk when height is large. (#13710)
上级
65ed45a1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
64 addition
and
27 deletion
+64
-27
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+64
-27
未找到文件。
paddle/fluid/operators/top_k_op.cu
浏览文件 @
41e4f7ea
...
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
...
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
* 3. go to the second step, until one thread's topk value is null;
* 3. go to the second step, until one thread's topk value is null;
* 4. go to the first step, until get the topk value.
* 4. go to the first step, until get the topk value.
*/
*/
// Top-k kernel over the rows of a matrix.
//
// Each thread block handles rows bid, bid + grid_dim, bid + 2 * grid_dim, ...
// up to `num`, so a launch with fewer blocks than rows still visits every row.
//
// Template parameters:
//   T          element type of the input/output values.
//   MaxLength  size of the per-thread candidate buffer `topk`.
//   BlockSize  number of threads per block; sizes the shared scratch arrays.
//
// Parameters:
//   output         base of the value output; row i starts at
//                  output + i * output_stride.
//   output_stride  distance, in elements, between consecutive output rows.
//   indices        base of the index output; row i starts at indices + i * k.
//   src            base of the input; row i starts at src + i * lds.
//   lds            distance, in elements, between consecutive input rows.
//   dim            forwarded to ThreadGetTopK (the host passes the input width).
//   k              number of entries to produce per row.
//   grid_dim       row stride between iterations of one block (the host passes
//                  the launch grid size).
//   num            total number of rows.
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
                             const T* src, int lds, int dim, int k,
                             int grid_dim, int num) {
  __shared__ Pair<T> sh_topk[BlockSize];
  __shared__ int maxid[BlockSize / 2];
  const int tid = threadIdx.x;
  const int warp = threadIdx.x / 32;

  const int bid = blockIdx.x;
  for (int i = bid; i < num; i += grid_dim) {
    // Per-row cursors and counter. BlockReduce receives these by pointer and
    // the while-loop below only terminates once the counter reaches zero, so
    // the kernel parameters themselves must stay untouched: mutating
    // `output` / `indices` / `k` directly would accumulate offsets across
    // rows and leave k == 0 for every row after the first one of this block.
    T* row_output = output + i * output_stride;
    int64_t* row_indices = indices + i * k;
    int remaining = k;

    Pair<T> topk[MaxLength];
    int beam = MaxLength;
    Pair<T> max;
    bool is_empty = false;
    bool firststep = true;

    // Reset the per-thread candidate buffer to (-inf, -1) sentinels.
    for (int j = 0; j < MaxLength; j++) {
      topk[j].set(-INFINITY, -1);
    }
    while (remaining) {
      ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, remaining,
                                             src + i * lds, &firststep,
                                             &is_empty, &max, dim, tid);

      // Publish this thread's first candidate for the block-wide reduction.
      sh_topk[tid] = topk[0];
      // NOTE(review): BlockReduce is defined outside this view; it is
      // presumably what advances the cursors and decrements `remaining`.
      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &row_output,
                                           &row_indices, &beam, &remaining,
                                           tid, warp);
    }
  }
}

// Picks the thread-block size for a given `dim`:
// > 128 -> 256, > 64 -> 128, > 32 -> 64, otherwise 32.
inline static int GetDesiredBlockDim(int dim) {
  if (dim > 128) {
    return 256;
  } else if (dim > 64) {
    return 128;
  } else if (dim > 32) {
    return 64;
  } else {
    return 32;
  }
}
// Expands to one `case (dim): { ... } break` arm of a switch statement.
// Inside the arm, `kBlockDim` is a constexpr copy of `dim`, so the statement
// passed as the variadic argument can use it as a template argument.
// Must be used inside a `switch`; the caller supplies the trailing `;`.
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
// Expands to four switch arms, for block sizes 256, 128, 64 and 32 (the values
// GetDesiredBlockDim returns), each running the given statement with the
// matching `kBlockDim`. `##__VA_ARGS__` is the GNU comma-swallowing form of
// variadic argument forwarding.
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template
<
typename
T
>
template
<
typename
T
>
class
TopkOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
TopkOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
...
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// NOTE: pass lds and dim same to input width.
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): refine this kernel.
// TODO(typhoonzero): refine this kernel.
dim3
threads
(
256
,
1
);
const
int
kMaxHeight
=
2048
;
dim3
grid
(
input_height
,
1
);
int
gridx
=
input_height
<
kMaxHeight
?
input_height
:
kMaxHeight
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
KeMatrixTopK
<
T
,
5
,
256
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
switch
(
GetDesiredBlockDim
(
input_width
))
{
ctx
.
device_context
())
FIXED_BLOCK_DIM
(
.
stream
()
>>>
(
KeMatrixTopK
<
T
,
5
,
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
kBlockDim
><<<
gridx
,
kBlockDim
,
0
,
dev_ctx
.
stream
()
>>>
(
input_width
,
static_cast
<
int
>
(
k
));
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
static_cast
<
int
>
(
k
),
gridx
,
input_height
));
default:
PADDLE_THROW
(
"Error"
);
}
}
}
};
};
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录