Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
0742f5c5
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0742f5c5
编写于
7月 07, 2021
作者:
L
LDOUBLEV
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix metric etc.al
上级
a7b32ca8
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
60 addition
and
33 deletion
+60
-33
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
+4
-4
ppocr/losses/basic_loss.py
ppocr/losses/basic_loss.py
+31
-6
ppocr/losses/combined_loss.py
ppocr/losses/combined_loss.py
+0
-2
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+9
-9
ppocr/metrics/det_metric.py
ppocr/metrics/det_metric.py
+4
-0
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+3
-6
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+5
-4
tools/eval.py
tools/eval.py
+4
-2
未找到文件。
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
浏览文件 @
0742f5c5
...
...
@@ -90,14 +90,14 @@ Loss:
-
[
"
Student"
,
"
Student2"
]
maps_name
:
"
thrink_maps"
weight
:
1.0
act
:
"
softmax"
# act: None
model_name_pairs
:
[
"
Student"
,
"
Student2"
]
key
:
maps
-
DistillationDBLoss
:
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Student2"
]
# key: maps
name
:
DBLoss
#
name: DBLoss
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
...
...
@@ -119,8 +119,8 @@ Optimizer:
PostProcess
:
name
:
DistillationDBPostProcess
model_name
:
[
"
Student"
,
"
Student2"
]
key
:
head_out
model_name
:
[
"
Student"
,
"
Student2"
,
"
Teacher"
]
# key: maps
thresh
:
0.3
box_thresh
:
0.6
max_candidates
:
1000
...
...
ppocr/losses/basic_loss.py
浏览文件 @
0742f5c5
...
...
@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
return
loss
class
KLJSLoss
(
object
):
def
__init__
(
self
,
mode
=
'kl'
):
assert
mode
in
[
'kl'
,
'js'
,
'KL'
,
'JS'
],
"mode can only be one of ['kl', 'js', 'KL', 'JS']"
self
.
mode
=
mode
def
__call__
(
self
,
p1
,
p2
,
reduction
=
"mean"
):
loss
=
paddle
.
multiply
(
p2
,
paddle
.
log
(
(
p2
+
1e-5
)
/
(
p1
+
1e-5
)
+
1e-5
))
if
self
.
mode
.
lower
()
==
"js"
:
loss
+=
paddle
.
multiply
(
p1
,
paddle
.
log
((
p1
+
1e-5
)
/
(
p2
+
1e-5
)
+
1e-5
))
loss
*=
0.5
if
reduction
==
"mean"
:
loss
=
paddle
.
mean
(
loss
,
axis
=
[
1
,
2
])
elif
reduction
==
"none"
or
reduction
is
None
:
return
loss
else
:
loss
=
paddle
.
sum
(
loss
,
axis
=
[
1
,
2
])
return
loss
class
DMLLoss
(
nn
.
Layer
):
"""
DMLLoss
...
...
@@ -69,17 +90,21 @@ class DMLLoss(nn.Layer):
self
.
act
=
nn
.
Sigmoid
()
else
:
self
.
act
=
None
self
.
jskl_loss
=
KLJSLoss
(
mode
=
"js"
)
def
forward
(
self
,
out1
,
out2
):
if
self
.
act
is
not
None
:
out1
=
self
.
act
(
out1
)
out2
=
self
.
act
(
out2
)
log_out1
=
paddle
.
log
(
out1
)
log_out2
=
paddle
.
log
(
out2
)
loss
=
(
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out2
,
out1
,
reduction
=
'batchmean'
))
/
2.0
if
len
(
out1
.
shape
)
<
2
:
log_out1
=
paddle
.
log
(
out1
)
log_out2
=
paddle
.
log
(
out2
)
loss
=
(
F
.
kl_div
(
log_out1
,
out2
,
reduction
=
'batchmean'
)
+
F
.
kl_div
(
log_out2
,
out1
,
reduction
=
'batchmean'
))
/
2.0
else
:
loss
=
self
.
jskl_loss
(
out1
,
out2
)
return
loss
...
...
ppocr/losses/combined_loss.py
浏览文件 @
0742f5c5
...
...
@@ -55,7 +55,5 @@ class CombinedLoss(nn.Layer):
loss_all
+=
loss
[
key
]
*
weight
else
:
loss_dict
[
"{}_{}"
.
format
(
key
,
idx
)]
=
loss
[
key
]
# loss[f"{key}_{idx}"] = loss[key]
loss_dict
.
update
(
loss
)
loss_dict
[
"loss"
]
=
loss_all
return
loss_dict
ppocr/losses/distillation_loss.py
浏览文件 @
0742f5c5
...
...
@@ -46,13 +46,13 @@ class DistillationDMLLoss(DMLLoss):
act
=
None
,
key
=
None
,
maps_name
=
None
,
name
=
"
loss_
dml"
):
name
=
"dml"
):
super
().
__init__
(
act
=
act
)
assert
isinstance
(
model_name_pairs
,
list
)
self
.
key
=
key
self
.
model_name_pairs
=
self
.
_check_model_name_pairs
(
model_name_pairs
)
self
.
name
=
name
self
.
maps_name
=
maps_name
self
.
maps_name
=
self
.
_check_maps_name
(
maps_name
)
def
_check_model_name_pairs
(
self
,
model_name_pairs
):
if
not
isinstance
(
model_name_pairs
,
list
):
...
...
@@ -76,11 +76,11 @@ class DistillationDMLLoss(DMLLoss):
new_outs
=
{}
for
k
in
self
.
maps_name
:
if
k
==
"thrink_maps"
:
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
[
1
],
starts
=
[
0
],
ends
=
[
1
])
new_outs
[
k
]
=
outs
[:,
0
,
:,
:]
elif
k
==
"threshold_maps"
:
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
[
1
],
starts
=
[
1
],
ends
=
[
2
])
new_outs
[
k
]
=
outs
[:,
1
,
:,
:]
elif
k
==
"binary_maps"
:
new_outs
[
k
]
=
paddle
.
slice
(
outs
,
axes
=
[
1
],
starts
=
[
2
],
ends
=
[
3
])
new_outs
[
k
]
=
outs
[:,
2
,
:,
:]
else
:
continue
return
new_outs
...
...
@@ -105,16 +105,16 @@ class DistillationDMLLoss(DMLLoss):
else
:
outs1
=
self
.
_slice_out
(
out1
)
outs2
=
self
.
_slice_out
(
out2
)
for
k
in
outs1
.
keys
(
):
for
_c
,
k
in
enumerate
(
outs1
.
keys
()
):
loss
=
super
().
forward
(
outs1
[
k
],
outs2
[
k
])
if
isinstance
(
loss
,
dict
):
for
key
in
loss
:
loss_dict
[
"{}_{}_{}_{}_{}"
.
format
(
key
,
pair
[
0
],
pair
[
1
],
map_name
,
idx
)]
=
loss
[
key
]
else
:
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps_name
,
loss_dict
[
"{}_{}_{}"
.
format
(
self
.
name
,
self
.
maps_name
[
_c
]
,
idx
)]
=
loss
loss_dict
=
_sum_loss
(
loss_dict
)
return
loss_dict
...
...
@@ -152,7 +152,7 @@ class DistillationDBLoss(DBLoss):
beta
=
10
,
ohem_ratio
=
3
,
eps
=
1e-6
,
name
=
"db
_loss
"
,
name
=
"db"
,
**
kwargs
):
super
().
__init__
()
self
.
model_name_list
=
model_name_list
...
...
ppocr/metrics/det_metric.py
浏览文件 @
0742f5c5
...
...
@@ -55,6 +55,10 @@ class DetMetric(object):
result
=
self
.
evaluator
.
evaluate_image
(
gt_info_list
,
det_info_list
)
self
.
results
.
append
(
result
)
metircs
=
self
.
evaluator
.
combine_results
(
self
.
results
)
self
.
reset
()
return
metircs
def
get_metric
(
self
):
"""
return metrics {
...
...
ppocr/postprocess/db_postprocess.py
浏览文件 @
0742f5c5
...
...
@@ -200,21 +200,18 @@ class DistillationDBPostProcess(DBPostProcess):
use_dilation
=
False
,
score_mode
=
"fast"
,
**
kwargs
):
super
(
DistillationDBPostProcess
,
self
).
__init__
(
thresh
,
box_thresh
,
max_candidates
,
unclip_ratio
,
use_dilation
,
score_mode
)
super
().
__init__
()
if
not
isinstance
(
model_name
,
list
):
model_name
=
[
model_name
]
self
.
model_name
=
model_name
self
.
key
=
key
def
forward
(
self
,
predicts
,
shape_list
):
def
__call__
(
self
,
predicts
,
shape_list
):
results
=
{}
for
name
in
self
.
model_name
:
pred
=
predicts
[
name
]
if
self
.
key
is
not
None
:
pred
=
pred
[
self
.
key
]
results
[
name
]
=
super
().
__call__
(
pred
,
shape_list
=
label
)
results
[
name
]
=
super
().
__call__
(
pred
,
shape_list
=
shape_list
)
return
results
ppocr/utils/save_load.py
浏览文件 @
0742f5c5
...
...
@@ -130,11 +130,12 @@ def load_pretrained_params(model, path):
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
new_state_dict
[
k1
]
=
params
[
k2
]
else
:
print
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
)
else
:
print
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
)
model
.
set_state_dict
(
new_state_dict
)
print
(
f
"load pretrain successful from
{
path
}
"
)
return
True
def
save_model
(
model
,
...
...
tools/eval.py
浏览文件 @
0742f5c5
...
...
@@ -55,8 +55,10 @@ def main():
model
=
build_model
(
config
[
'Architecture'
])
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
model_type
=
config
[
'Architecture'
][
'model_type'
]
if
"model_type"
in
config
[
'Architecture'
].
keys
():
model_type
=
config
[
'Architecture'
][
'model_type'
]
else
:
model_type
=
None
best_model_dict
=
init_model
(
config
,
model
)
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录