Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8f2656ef
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8f2656ef
编写于
11月 16, 2020
作者:
W
wawltor
提交者:
GitHub
11月 16, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix the gradient bug for the topk v2
fix the gradient bug for the topk v2
上级
a972c33f
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
25 addition
and
19 deletion
+25
-19
paddle/fluid/operators/top_k_function_cuda.h
paddle/fluid/operators/top_k_function_cuda.h
+7
-5
python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
+18
-14
未找到文件。
paddle/fluid/operators/top_k_function_cuda.h
浏览文件 @
8f2656ef
...
@@ -335,6 +335,7 @@ __global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
...
@@ -335,6 +335,7 @@ __global__ void AssignGrad(T* x_grad, const int64_t* indices, const T* out_grad,
for
(
size_t
j
=
0
;
j
<
cols
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
cols
;
++
j
)
{
x_grad
[
i
*
cols
+
j
]
=
0
;
x_grad
[
i
*
cols
+
j
]
=
0
;
}
}
__syncthreads
();
for
(
size_t
j
=
0
;
j
<
k
;
++
j
)
{
for
(
size_t
j
=
0
;
j
<
k
;
++
j
)
{
size_t
idx
=
indices
[
i
*
k
+
j
];
size_t
idx
=
indices
[
i
*
k
+
j
];
x_grad
[
i
*
cols
+
idx
]
=
out_grad
[
i
*
k
+
j
];
x_grad
[
i
*
cols
+
idx
]
=
out_grad
[
i
*
k
+
j
];
...
@@ -349,15 +350,16 @@ __global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices,
...
@@ -349,15 +350,16 @@ __global__ void AssignGradWithAxis(const T* grad_out, const int64_t* indices,
int
raw_height
,
int
k
)
{
int
raw_height
,
int
k
)
{
// raw_height is the length of topk axis
// raw_height is the length of topk axis
for
(
int
i
=
blockIdx
.
x
;
i
<
pre
;
i
+=
gridDim
.
x
)
{
for
(
int
i
=
blockIdx
.
x
;
i
<
pre
;
i
+=
gridDim
.
x
)
{
const
int
&
base_index
=
i
*
post
*
k
;
int
base_index
=
i
*
post
*
k
;
const
int
&
base_grad
=
i
*
post
*
raw_height
;
int
base_grad
=
i
*
post
*
raw_height
;
for
(
int
j
=
threadIdx
.
x
;
j
<
raw_height
*
post
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
raw_height
*
post
;
j
+=
blockDim
.
x
)
{
grad_in
[
base_grad
+
j
]
=
static_cast
<
T
>
(
0
);
grad_in
[
base_grad
+
j
]
=
static_cast
<
T
>
(
0
);
}
}
__syncthreads
();
for
(
int
j
=
threadIdx
.
x
;
j
<
k
*
post
;
j
+=
blockDim
.
x
)
{
for
(
int
j
=
threadIdx
.
x
;
j
<
k
*
post
;
j
+=
blockDim
.
x
)
{
const
int64_t
idx_ij
=
indices
[
base_index
+
j
];
int64_t
idx_ij
=
indices
[
base_index
+
j
];
const
int64_t
in_ij
=
base_grad
+
(
idx_ij
*
post
)
+
(
j
%
post
);
int64_t
in_ij
=
base_grad
+
(
idx_ij
*
post
)
+
(
j
%
post
);
grad_in
[
in_ij
]
=
grad_out
[
idx_i
j
];
grad_in
[
in_ij
]
=
grad_out
[
base_index
+
j
];
}
}
}
}
}
}
...
...
python/paddle/fluid/tests/unittests/test_top_k_v2_op.py
浏览文件 @
8f2656ef
...
@@ -64,34 +64,38 @@ class TestTopkOp(OpTest):
...
@@ -64,34 +64,38 @@ class TestTopkOp(OpTest):
class
TestTopkOp1
(
TestTopkOp
):
class
TestTopkOp1
(
TestTopkOp
):
def
init_args
(
self
):
self
.
k
=
3
self
.
axis
=
0
self
.
largest
=
True
class
TestTopkOp2
(
TestTopkOp
):
def
init_args
(
self
):
def
init_args
(
self
):
self
.
k
=
3
self
.
k
=
3
self
.
axis
=
0
self
.
axis
=
0
self
.
largest
=
False
self
.
largest
=
False
class
TestTopkOp
3
(
TestTopkOp
):
class
TestTopkOp
2
(
TestTopkOp
):
def
init_args
(
self
):
def
init_args
(
self
):
self
.
k
=
4
self
.
k
=
4
self
.
axis
=
0
self
.
axis
=
0
self
.
largest
=
False
self
.
largest
=
False
class
TestTopkOp
4
(
TestTopkOp
):
class
TestTopkOp
3
(
OpTest
):
def
init_args
(
self
):
def
init_args
(
self
):
self
.
k
=
4
self
.
k
=
6
self
.
axis
=
0
self
.
axis
=
1
self
.
largest
=
Fals
e
self
.
largest
=
Tru
e
def
setUp
(
self
):
self
.
op_type
=
"top_k_v2"
self
.
dtype
=
np
.
float64
self
.
input_data
=
np
.
random
.
rand
(
16
,
100
)
self
.
init_args
()
self
.
inputs
=
{
'X'
:
self
.
input_data
}
self
.
attrs
=
{
'k'
:
self
.
k
,
'axis'
:
self
.
axis
,
'largest'
:
self
.
largest
}
output
,
indices
=
numpy_topk
(
self
.
input_data
,
axis
=
self
.
axis
,
k
=
self
.
k
,
largest
=
self
.
largest
)
self
.
outputs
=
{
'Out'
:
output
,
'Indices'
:
indices
}
class
TestTopkOp5
(
TestTopkOp
):
class
TestTopkOp4
(
TestTopkOp
):
def
init_args
(
self
):
def
init_args
(
self
):
self
.
k
=
3
self
.
k
=
3
self
.
axis
=
1
self
.
axis
=
1
...
@@ -109,7 +113,7 @@ class TestTopkOp5(TestTopkOp):
...
@@ -109,7 +113,7 @@ class TestTopkOp5(TestTopkOp):
self
.
outputs
=
{
'Out'
:
output
,
'Indices'
:
indices
}
self
.
outputs
=
{
'Out'
:
output
,
'Indices'
:
indices
}
class
TestTopkOp
6
(
TestTopkOp
):
class
TestTopkOp
5
(
TestTopkOp
):
def
init_args
(
self
):
def
init_args
(
self
):
self
.
k
=
3
self
.
k
=
3
self
.
axis
=
1
self
.
axis
=
1
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录