Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
41e4f7ea
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
41e4f7ea
编写于
10月 08, 2018
作者:
Q
qingqing01
提交者:
GitHub
10月 08, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize Topk when height is large. (#13710)
上级
65ed45a1
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
64 addition
and
27 deletion
+64
-27
paddle/fluid/operators/top_k_op.cu
paddle/fluid/operators/top_k_op.cu
+64
-27
未找到文件。
paddle/fluid/operators/top_k_op.cu
浏览文件 @
41e4f7ea
...
...
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
* 3. go to the second setp, until one thread's topk value is null;
* 4. go to the first setp, until get the topk value.
*/
template
<
typename
T
,
int
MaxLength
,
int
BlockSize
>
__global__
void
KeMatrixTopK
(
T
*
output
,
int
output_stride
,
int64_t
*
indices
,
const
T
*
src
,
int
lds
,
int
dim
,
int
k
)
{
const
T
*
src
,
int
lds
,
int
dim
,
int
k
,
int
grid_dim
,
int
num
)
{
__shared__
Pair
<
T
>
sh_topk
[
BlockSize
];
__shared__
int
maxid
[
BlockSize
/
2
];
const
int
tid
=
threadIdx
.
x
;
const
int
warp
=
threadIdx
.
x
/
32
;
output
+=
blockIdx
.
x
*
output_stride
;
indices
+=
blockIdx
.
x
*
k
;
Pair
<
T
>
topk
[
MaxLength
];
int
beam
=
MaxLength
;
Pair
<
T
>
max
;
bool
is_empty
=
false
;
bool
firststep
=
true
;
const
int
bid
=
blockIdx
.
x
;
for
(
int
i
=
bid
;
i
<
num
;
i
+=
grid_dim
)
{
output
+=
i
*
output_stride
;
indices
+=
i
*
k
;
Pair
<
T
>
topk
[
MaxLength
];
int
beam
=
MaxLength
;
Pair
<
T
>
max
;
bool
is_empty
=
false
;
bool
firststep
=
true
;
for
(
int
k
=
0
;
k
<
MaxLength
;
k
++
)
{
topk
[
k
].
set
(
-
INFINITY
,
-
1
);
}
while
(
k
)
{
ThreadGetTopK
<
T
,
MaxLength
,
BlockSize
>
(
topk
,
&
beam
,
k
,
src
+
i
*
lds
,
&
firststep
,
&
is_empty
,
&
max
,
dim
,
tid
);
for
(
int
k
=
0
;
k
<
MaxLength
;
k
++
)
{
topk
[
k
].
set
(
-
INFINITY
,
-
1
);
sh_topk
[
tid
]
=
topk
[
0
];
BlockReduce
<
T
,
MaxLength
,
BlockSize
>
(
sh_topk
,
maxid
,
topk
,
&
output
,
&
indices
,
&
beam
,
&
k
,
tid
,
warp
);
}
}
while
(
k
)
{
ThreadGetTopK
<
T
,
MaxLength
,
BlockSize
>
(
topk
,
&
beam
,
k
,
src
+
blockIdx
.
x
*
lds
,
&
firststep
,
&
is_empty
,
&
max
,
dim
,
tid
);
sh_topk
[
tid
]
=
topk
[
0
];
BlockReduce
<
T
,
MaxLength
,
BlockSize
>
(
sh_topk
,
maxid
,
topk
,
&
output
,
&
indices
,
&
beam
,
&
k
,
tid
,
warp
);
}
inline
static
int
GetDesiredBlockDim
(
int
dim
)
{
if
(
dim
>
128
)
{
return
256
;
}
else
if
(
dim
>
64
)
{
return
128
;
}
else
if
(
dim
>
32
)
{
return
64
;
}
else
{
return
32
;
}
}
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
case (dim): { \
constexpr auto kBlockDim = (dim); \
__VA_ARGS__; \
} break
#define FIXED_BLOCK_DIM(...) \
FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \
FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
template
<
typename
T
>
class
TopkOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
...
...
@@ -310,18 +339,26 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): refine this kernel.
dim3
threads
(
256
,
1
);
dim3
grid
(
input_height
,
1
);
KeMatrixTopK
<
T
,
5
,
256
><<<
grid
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
()
>>>
(
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
static_cast
<
int
>
(
k
));
const
int
kMaxHeight
=
2048
;
int
gridx
=
input_height
<
kMaxHeight
?
input_height
:
kMaxHeight
;
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
switch
(
GetDesiredBlockDim
(
input_width
))
{
FIXED_BLOCK_DIM
(
KeMatrixTopK
<
T
,
5
,
kBlockDim
><<<
gridx
,
kBlockDim
,
0
,
dev_ctx
.
stream
()
>>>
(
output_data
,
output
->
dims
()[
1
],
indices_data
,
input_data
,
input_width
,
input_width
,
static_cast
<
int
>
(
k
),
gridx
,
input_height
));
default:
PADDLE_THROW
(
"Error"
);
}
}
};
#undef FIXED_BLOCK_DIM_BASE
#undef FIXED_BLOCK_DIM
}
// namespace operators
}
// namespace paddle
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录