Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_41840029
PaddleOCR
提交
0343756e
P
PaddleOCR
项目概览
weixin_41840029
/
PaddleOCR
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleOCR
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0343756e
编写于
6月 03, 2021
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix metric
上级
b48f7609
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
41 addition
and
38 deletion
+41
-38
configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
...h_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
+4
-4
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+6
-15
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+8
-5
ppocr/metrics/__init__.py
ppocr/metrics/__init__.py
+12
-9
ppocr/postprocess/rec_postprocess.py
ppocr/postprocess/rec_postprocess.py
+11
-5
未找到文件。
configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_distillation_v2.1.yml
浏览文件 @
0343756e
...
...
@@ -95,17 +95,17 @@ Loss:
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
key
:
backbone_out
PostProcess
:
name
:
DistillationCTCLabelDecode
model_name
:
"
Student"
model_name
:
[
"
Student"
,
"
Teacher"
]
key
:
head_out
Metric
:
name
:
RecMetric
name
:
DistillationMetric
base_metric_name
:
RecMetric
main_indicator
:
acc
key
:
"
Student"
Train
:
dataset
:
...
...
ppocr/losses/basic_loss.py
浏览文件 @
0343756e
...
...
@@ -22,9 +22,8 @@ from paddle.nn import SmoothL1Loss
class
CELoss
(
nn
.
Layer
):
def
__init__
(
self
,
name
=
"loss_ce"
,
epsilon
=
None
):
def
__init__
(
self
,
epsilon
=
None
):
super
().
__init__
()
self
.
name
=
name
if
epsilon
is
not
None
and
(
epsilon
<=
0
or
epsilon
>=
1
):
epsilon
=
None
self
.
epsilon
=
epsilon
...
...
@@ -52,9 +51,7 @@ class CELoss(nn.Layer):
else
:
soft_label
=
False
loss
=
F
.
cross_entropy
(
x
,
label
=
label
,
soft_label
=
soft_label
)
loss_dict
[
self
.
name
]
=
paddle
.
mean
(
loss
)
return
loss_dict
return
loss
class
DMLLoss
(
nn
.
Layer
):
...
...
@@ -62,11 +59,10 @@ class DMLLoss(nn.Layer):
DMLLoss
"""
def
__init__
(
self
,
act
=
None
,
name
=
"loss_dml"
):
def
__init__
(
self
,
act
=
None
):
super
().
__init__
()
if
act
is
not
None
:
assert
act
in
[
"softmax"
,
"sigmoid"
]
self
.
name
=
name
if
act
==
"softmax"
:
self
.
act
=
nn
.
Softmax
(
axis
=-
1
)
elif
act
==
"sigmoid"
:
...
...
@@ -75,7 +71,6 @@ class DMLLoss(nn.Layer):
self
.
act
=
None
def
forward
(
self
,
out1
,
out2
):
loss_dict
=
{}
if
self
.
act
is
not
None
:
out1
=
self
.
act
(
out1
)
out2
=
self
.
act
(
out2
)
...
...
@@ -85,18 +80,16 @@ class DMLLoss(nn.Layer):
loss
=
(
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out2
,
log_out1
,
reduction
=
'batchmean'
))
/
2.0
loss_dict
[
self
.
name
]
=
loss
return
loss_dict
return
loss
class
DistanceLoss
(
nn
.
Layer
):
"""
DistanceLoss:
mode: loss mode
name: loss key in the output dict
"""
def
__init__
(
self
,
mode
=
"l2"
,
name
=
"loss_dist"
,
**
kargs
):
def
__init__
(
self
,
mode
=
"l2"
,
**
kargs
):
super
().
__init__
()
assert
mode
in
[
"l1"
,
"l2"
,
"smooth_l1"
]
if
mode
==
"l1"
:
...
...
@@ -106,7 +99,5 @@ class DistanceLoss(nn.Layer):
elif
mode
==
"smooth_l1"
:
self
.
loss_func
=
nn
.
SmoothL1Loss
(
**
kargs
)
self
.
name
=
"{}_{}"
.
format
(
name
,
mode
)
def
forward
(
self
,
x
,
y
):
return
{
self
.
name
:
self
.
loss_func
(
x
,
y
)}
return
self
.
loss_func
(
x
,
y
)
ppocr/losses/distillation_loss.py
浏览文件 @
0343756e
...
...
@@ -26,10 +26,11 @@ class DistillationDMLLoss(DMLLoss):
def
__init__
(
self
,
model_name_pairs
=
[],
act
=
None
,
key
=
None
,
name
=
"loss_dml"
):
super
().
__init__
(
act
=
act
,
name
=
name
)
super
().
__init__
(
act
=
act
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
...
...
@@ -42,8 +43,8 @@ class DistillationDMLLoss(DMLLoss):
loss
=
super
().
forward
(
out1
,
out2
)
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
loss_dict
[
"{}_{}_{}
"
.
format
(
self
.
name
,
key
,
idx
)]
=
loss
[
key
]
loss_dict
[
"{}_{}_{}
_{}"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
idx
)]
=
loss
[
key
]
else
:
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
idx
)]
=
loss
return
loss_dict
...
...
@@ -82,10 +83,11 @@ class DistillationDistanceLoss(DistanceLoss):
key
=
None
,
name
=
"loss_distance"
,
**
kargs
):
super
().
__init__
(
mode
=
mode
,
name
=
name
,
**
kargs
)
super
().
__init__
(
mode
=
mode
,
**
kargs
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
+
"_l2"
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
...
...
@@ -101,5 +103,6 @@ class DistillationDistanceLoss(DistanceLoss):
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
key
,
idx
)]
=
loss
[
key
]
else
:
loss_dict
[
"{}_{}"
.
format
(
self
.
name
,
idx
)]
=
loss
loss_dict
[
"{}_{}_{}_{}"
.
format
(
self
.
name
,
pair
[
0
],
pair
[
1
],
idx
)]
=
loss
return
loss_dict
ppocr/metrics/__init__.py
浏览文件 @
0343756e
...
...
@@ -19,20 +19,23 @@ from __future__ import unicode_literals
import
copy
__all__
=
[
'build_metric'
]
__all__
=
[
"build_metric"
]
from
.det_metric
import
DetMetric
from
.rec_metric
import
RecMetric
from
.cls_metric
import
ClsMetric
from
.e2e_metric
import
E2EMetric
from
.distillation_metric
import
DistillationMetric
def
build_metric
(
config
):
from
.det_metric
import
DetMetric
from
.rec_metric
import
RecMetric
from
.cls_metric
import
ClsMetric
from
.e2e_metric
import
E2EMetric
support_dict
=
[
'DetMetric'
,
'RecMetric'
,
'ClsMetric'
,
'E2EMetric'
]
def
build_metric
(
config
):
support_dict
=
[
"DetMetric"
,
"RecMetric"
,
"ClsMetric"
,
"E2EMetric"
,
"DistillationMetric"
]
config
=
copy
.
deepcopy
(
config
)
module_name
=
config
.
pop
(
'name'
)
module_name
=
config
.
pop
(
"name"
)
assert
module_name
in
support_dict
,
Exception
(
'metric only support {}'
.
format
(
support_dict
))
"metric only support {}"
.
format
(
support_dict
))
module_class
=
eval
(
module_name
)(
**
config
)
return
module_class
ppocr/postprocess/rec_postprocess.py
浏览文件 @
0343756e
...
...
@@ -135,19 +135,25 @@ class DistillationCTCLabelDecode(CTCLabelDecode):
character_dict_path
=
None
,
character_type
=
'ch'
,
use_space_char
=
False
,
model_name
=
"student"
,
model_name
=
[
"student"
]
,
key
=
None
,
**
kwargs
):
super
(
DistillationCTCLabelDecode
,
self
).
__init__
(
character_dict_path
,
character_type
,
use_space_char
)
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
def
__call__
(
self
,
preds
,
label
=
None
,
*
args
,
**
kwargs
):
pred
=
preds
[
self
.
model_name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
return
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
output
=
dict
()
for
name
in
self
.
model_name
:
pred
=
preds
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
output
[
name
]
=
super
().
__call__
(
pred
,
label
=
label
,
*
args
,
**
kwargs
)
return
output
class
AttnLabelDecode
(
BaseRecLabelDecode
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录