Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
coolalex776
Paddle
提交
7689b6aa
P
Paddle
项目概览
coolalex776
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
7689b6aa
编写于
12月 23, 2019
作者:
G
Guo Sheng
提交者:
GitHub
12月 23, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix default label dim of label_smooth_op. test=develop (#21862)
上级
13e4756f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
23 addition
and
5 deletion
+23
-5
paddle/fluid/operators/label_smooth_op.cc
paddle/fluid/operators/label_smooth_op.cc
+1
-1
paddle/fluid/operators/label_smooth_op.cu
paddle/fluid/operators/label_smooth_op.cu
+2
-2
paddle/fluid/operators/label_smooth_op.h
paddle/fluid/operators/label_smooth_op.h
+2
-2
python/paddle/fluid/tests/unittests/test_label_smooth_op.py
python/paddle/fluid/tests/unittests/test_label_smooth_op.py
+18
-0
未找到文件。
paddle/fluid/operators/label_smooth_op.cc
浏览文件 @
7689b6aa
...
...
@@ -37,7 +37,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
auto
noise_dims
=
ctx
->
GetInputDim
(
"PriorDist"
);
auto
noise_numel
=
paddle
::
framework
::
product
(
noise_dims
);
PADDLE_ENFORCE
(
in_dims
[
1
]
==
noise_numel
,
in_dims
[
in_dims
.
size
()
-
1
]
==
noise_numel
,
"The number of elements in Input(PriorDist) must be equal to the "
"dimension of each label."
);
}
...
...
paddle/fluid/operators/label_smooth_op.cu
浏览文件 @
7689b6aa
...
...
@@ -34,7 +34,7 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
const
T
*
dist_data
,
T
*
dst
)
{
int
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(;
idx
<
N
;
idx
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
dist_idx
=
idx
-
(
idx
/
dist_numel
)
*
dist_numel
;
int
dist_idx
=
idx
%
dist_numel
;
dst
[
idx
]
=
static_cast
<
T
>
(
1
-
epsilon
)
*
src
[
idx
]
+
static_cast
<
T
>
(
epsilon
)
*
dist_data
[
dist_idx
];
}
...
...
@@ -56,7 +56,7 @@ class LabelSmoothGPUKernel : public framework::OpKernel<T> {
auto
*
out_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
dist_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"PriorDist"
);
auto
label_dim
=
in_t
->
dims
()[
1
];
auto
label_dim
=
in_t
->
dims
()[
in_t
->
dims
().
size
()
-
1
];
auto
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
auto
&
dev
=
*
ctx
.
template
device_context
<
DeviceContext
>().
eigen_device
();
auto
size_prob
=
in_t
->
numel
();
...
...
paddle/fluid/operators/label_smooth_op.h
浏览文件 @
7689b6aa
...
...
@@ -27,7 +27,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
auto
*
out_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
in_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
dist_t
=
ctx
.
Input
<
framework
::
Tensor
>
(
"PriorDist"
);
auto
label_dim
=
in_t
->
dims
()[
1
];
auto
label_dim
=
in_t
->
dims
()[
in_t
->
dims
().
size
()
-
1
];
out_t
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
epsilon
=
ctx
.
Attr
<
float
>
(
"epsilon"
);
...
...
@@ -39,7 +39,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
out
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
in
+
static_cast
<
T
>
(
epsilon
)
*
dist
.
broadcast
(
Eigen
::
DSizes
<
int
,
1
>
(
in_t
->
numel
()));
dist
.
broadcast
(
Eigen
::
DSizes
<
int
,
1
>
(
in_t
->
numel
()
/
label_dim
));
}
else
{
out
.
device
(
dev
)
=
static_cast
<
T
>
(
1
-
epsilon
)
*
in
+
static_cast
<
T
>
(
epsilon
/
label_dim
);
...
...
python/paddle/fluid/tests/unittests/test_label_smooth_op.py
浏览文件 @
7689b6aa
...
...
@@ -53,5 +53,23 @@ class TestLabelSmoothOpWithPriorDist(TestLabelSmoothOp):
self
.
outputs
=
{
'Out'
:
smoothed_label
}
class
TestLabelSmoothOp3D
(
TestLabelSmoothOp
):
def
setUp
(
self
):
super
(
TestLabelSmoothOp3D
,
self
).
setUp
()
self
.
inputs
[
'X'
]
=
self
.
inputs
[
'X'
].
reshape
(
[
2
,
-
1
,
self
.
inputs
[
'X'
].
shape
[
-
1
]])
self
.
outputs
[
'Out'
]
=
self
.
outputs
[
'Out'
].
reshape
(
self
.
inputs
[
'X'
]
.
shape
)
class
TestLabelSmoothOpWithPriorDist3D
(
TestLabelSmoothOpWithPriorDist
):
def
setUp
(
self
):
super
(
TestLabelSmoothOpWithPriorDist3D
,
self
).
setUp
()
self
.
inputs
[
'X'
]
=
self
.
inputs
[
'X'
].
reshape
(
[
2
,
-
1
,
self
.
inputs
[
'X'
].
shape
[
-
1
]])
self
.
outputs
[
'Out'
]
=
self
.
outputs
[
'Out'
].
reshape
(
self
.
inputs
[
'X'
]
.
shape
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录