Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
634951f0
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
634951f0
编写于
9月 15, 2020
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix loss
上级
b8a7d186
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
20 addition
and
17 deletion
+20
-17
ppcls/modeling/loss.py
ppcls/modeling/loss.py
+20
-17
未找到文件。
ppcls/modeling/loss.py
浏览文件 @
634951f0
...
...
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle.fluid
as
fluid
import
paddle
import
paddle.nn.functional
as
F
__all__
=
[
'CELoss'
,
'MixCELoss'
,
'GoogLeNetLoss'
,
'JSDivLoss'
]
...
...
@@ -34,35 +35,37 @@ class Loss(object):
def
_labelsmoothing
(
self
,
target
):
if
target
.
shape
[
-
1
]
!=
self
.
_class_dim
:
one_hot_target
=
fluid
.
one_hot
(
input
=
target
,
depth
=
self
.
_class_dim
)
one_hot_target
=
F
.
one_hot
(
target
,
self
.
_class_dim
)
else
:
one_hot_target
=
target
soft_target
=
fluid
.
layers
.
label_smooth
(
label
=
one_hot_target
,
epsilon
=
self
.
_epsilon
,
dtype
=
"float32"
)
soft_target
=
fluid
.
layers
.
reshape
(
soft_target
,
shape
=
[
-
1
,
self
.
_class_dim
])
soft_target
=
F
.
label_smooth
(
one_hot_target
,
epsilon
=
self
.
_epsilon
,
dtype
=
"float32"
)
soft_target
=
paddle
.
reshape
(
soft_target
,
shape
=
[
-
1
,
self
.
_class_dim
])
return
soft_target
def
_crossentropy
(
self
,
input
,
target
):
if
self
.
_label_smoothing
:
target
=
self
.
_labelsmoothing
(
target
)
softmax_out
=
fluid
.
layers
.
softmax
(
input
,
use_cudnn
=
False
)
cost
=
fluid
.
layers
.
cross_entropy
(
input
=
softmax_out
,
label
=
target
,
soft_label
=
self
.
_label_smoothing
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
input
=
-
F
.
log_softmax
(
input
,
axis
=-
1
)
log_probs
=
-
F
.
log_softmax
(
input
,
axis
=-
1
)
cost
=
paddle
.
reduce_sum
(
target
*
log_probs
,
dim
=-
1
)
else
:
# softmax_out = F.softmax(input)
cost
=
F
.
cross_entropy
(
input
=
input
,
label
=
target
)
avg_cost
=
paddle
.
mean
(
cost
)
return
avg_cost
def
_kldiv
(
self
,
input
,
target
):
cost
=
target
*
fluid
.
layers
.
log
(
target
/
input
)
*
self
.
_class_dim
cost
=
fluid
.
layers
.
sum
(
cost
)
cost
=
target
*
F
.
log
(
target
/
input
)
*
self
.
_class_dim
cost
=
paddle
.
sum
(
cost
)
return
cost
def
_jsdiv
(
self
,
input
,
target
):
input
=
fluid
.
layers
.
softmax
(
input
,
use_cudnn
=
False
)
target
=
fluid
.
layers
.
softmax
(
target
,
use_cudnn
=
False
)
input
=
F
.
softmax
(
input
)
target
=
F
.
softmax
(
target
)
cost
=
self
.
_kldiv
(
input
,
target
)
+
self
.
_kldiv
(
target
,
input
)
cost
=
cost
/
2
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
avg_cost
=
paddle
.
mean
(
cost
)
return
avg_cost
def
__call__
(
self
,
input
,
target
):
...
...
@@ -94,7 +97,7 @@ class MixCELoss(Loss):
cost0
=
self
.
_crossentropy
(
input
,
target0
)
cost1
=
self
.
_crossentropy
(
input
,
target1
)
cost
=
lam
*
cost0
+
(
1.0
-
lam
)
*
cost1
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
avg_cost
=
paddle
.
mean
(
cost
)
return
avg_cost
...
...
@@ -111,7 +114,7 @@ class GoogLeNetLoss(Loss):
cost1
=
self
.
_crossentropy
(
input1
,
target
)
cost2
=
self
.
_crossentropy
(
input2
,
target
)
cost
=
cost0
+
0.3
*
cost1
+
0.3
*
cost2
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
avg_cost
=
paddle
.
mean
(
cost
)
return
avg_cost
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录