Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d3cc7ac3
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d3cc7ac3
编写于
10月 30, 2017
作者:
F
fengjiayi
提交者:
GitHub
10月 30, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix top k op GPU code (#5221)
* Fix Type error * Fix error * Fix top_k_op GPU code data type
上级
b9056bb0
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
6 addition
and
6 deletion
+6
-6
paddle/operators/top_k_op.cu
paddle/operators/top_k_op.cu
+6
-6
未找到文件。
paddle/operators/top_k_op.cu
浏览文件 @
d3cc7ac3
...
...
@@ -23,9 +23,9 @@ using Tensor = framework::Tensor;
template
<
typename
T
>
struct
Pair
{
__device__
__forceinline__
Pair
()
{}
__device__
__forceinline__
Pair
(
T
value
,
int
id
)
:
v
(
value
),
id
(
id
)
{}
__device__
__forceinline__
Pair
(
T
value
,
int
64_t
id
)
:
v
(
value
),
id
(
id
)
{}
__device__
__forceinline__
void
set
(
T
value
,
int
id
)
{
__device__
__forceinline__
void
set
(
T
value
,
int
64_t
id
)
{
v
=
value
;
id
=
id
;
}
...
...
@@ -48,7 +48,7 @@ struct Pair {
}
T
v
;
int
id
;
int
64_t
id
;
};
template
<
typename
T
>
...
...
@@ -197,7 +197,7 @@ __device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
template
<
typename
T
,
int
MaxLength
,
int
BlockSize
>
__device__
__forceinline__
void
BlockReduce
(
Pair
<
T
>*
sh_topk
,
int
*
maxid
,
Pair
<
T
>
topk
[],
T
**
topVal
,
int
**
topIds
,
int
&
beam
,
int
&
k
,
int
64_t
**
topIds
,
int
&
beam
,
int
&
k
,
const
int
tid
,
const
int
warp
)
{
while
(
true
)
{
__syncthreads
();
...
...
@@ -249,7 +249,7 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
* 4. go to the first setp, until get the topk value.
*/
template
<
typename
T
,
int
MaxLength
,
int
BlockSize
>
__global__
void
KeMatrixTopK
(
T
*
output
,
int
output_stride
,
int
*
indices
,
__global__
void
KeMatrixTopK
(
T
*
output
,
int
output_stride
,
int
64_t
*
indices
,
const
T
*
src
,
int
lds
,
int
dim
,
int
k
)
{
__shared__
Pair
<
T
>
sh_topk
[
BlockSize
];
__shared__
int
maxid
[
BlockSize
/
2
];
...
...
@@ -293,7 +293,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
T
*
output_data
=
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
// FIXME(typhoonzero): data is always converted to type T?
int
*
indices_data
=
indices
->
mutable_data
<
in
t
>
(
ctx
.
GetPlace
());
int
64_t
*
indices_data
=
indices
->
mutable_data
<
int64_
t
>
(
ctx
.
GetPlace
());
size_t
input_height
=
input
->
dims
()[
0
];
size_t
input_width
=
input
->
dims
()[
1
];
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录