Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
185d1e1f
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
185d1e1f
编写于
7月 07, 2021
作者:
L
LDOUBLEV
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug
上级
a91bbd74
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
89 additions
and
12 deletions
+89
-12
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
+13
-5
ppocr/losses/distillation_loss.py
ppocr/losses/distillation_loss.py
+7
-2
ppocr/modeling/architectures/distillation_model.py
ppocr/modeling/architectures/distillation_model.py
+2
-2
ppocr/postprocess/__init__.py
ppocr/postprocess/__init__.py
+2
-2
ppocr/postprocess/db_postprocess.py
ppocr/postprocess/db_postprocess.py
+41
-0
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+20
-0
tools/program.py
tools/program.py
+4
-1
未找到文件。
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml
浏览文件 @
185d1e1f
...
@@ -20,7 +20,7 @@ Architecture:
...
@@ -20,7 +20,7 @@ Architecture:
algorithm
:
Distillation
algorithm
:
Distillation
Models
:
Models
:
Student
:
Student
:
pretrained
:
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params
:
false
freeze_params
:
false
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
...
@@ -37,7 +37,7 @@ Architecture:
...
@@ -37,7 +37,7 @@ Architecture:
name
:
DBHead
name
:
DBHead
k
:
50
k
:
50
Student2
:
Student2
:
pretrained
:
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params
:
false
freeze_params
:
false
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
...
@@ -55,6 +55,9 @@ Architecture:
...
@@ -55,6 +55,9 @@ Architecture:
name
:
DBHead
name
:
DBHead
k
:
50
k
:
50
Teacher
:
Teacher
:
pretrained
:
./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params
:
true
return_all_feats
:
false
model_type
:
det
model_type
:
det
algorithm
:
DB
algorithm
:
DB
Transform
:
Transform
:
...
@@ -73,7 +76,9 @@ Loss:
...
@@ -73,7 +76,9 @@ Loss:
loss_config_list
:
loss_config_list
:
-
DistillationDilaDBLoss
:
-
DistillationDilaDBLoss
:
weight
:
1.0
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Student2"
,
"
Teacher"
]
model_name_pairs
:
-
[
"
Student"
,
"
Teacher"
]
-
[
"
Student2"
,
"
Teacher"
]
key
:
maps
key
:
maps
balance_loss
:
true
balance_loss
:
true
main_loss_type
:
DiceLoss
main_loss_type
:
DiceLoss
...
@@ -81,13 +86,16 @@ Loss:
...
@@ -81,13 +86,16 @@ Loss:
beta
:
10
beta
:
10
ohem_ratio
:
3
ohem_ratio
:
3
-
DistillationDMLLoss
:
-
DistillationDMLLoss
:
model_name_pairs
:
-
[
"
Student"
,
"
Student2"
]
maps_name
:
[
"
thrink_maps"
]
maps_name
:
[
"
thrink_maps"
]
weight
:
1.0
weight
:
1.0
act
:
"
softmax"
act
:
"
softmax"
model_name_pairs
:
[
"
Student"
,
"
Student2"
]
model_name_pairs
:
[
"
Student"
,
"
Student2"
]
key
:
maps
key
:
maps
-
DistillationDBLoss
:
-
DistillationDBLoss
:
model_name_list
:
[
"
Student"
,
"
Teacher"
]
weight
:
1.0
model_name_list
:
[
"
Student"
,
"
Student2"
]
key
:
maps
key
:
maps
name
:
DBLoss
name
:
DBLoss
balance_loss
:
true
balance_loss
:
true
...
@@ -110,7 +118,7 @@ Optimizer:
...
@@ -110,7 +118,7 @@ Optimizer:
factor
:
0
factor
:
0
PostProcess
:
PostProcess
:
name
:
Distillation
CTDBPostProcessCLabelDecode
name
:
Distillation
DBPostProcess
model_name
:
[
"
Student"
,
"
Student2"
]
model_name
:
[
"
Student"
,
"
Student2"
]
key
:
head_out
key
:
head_out
thresh
:
0.3
thresh
:
0.3
...
...
ppocr/losses/distillation_loss.py
浏览文件 @
185d1e1f
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
import
paddle
import
paddle
import
paddle.nn
as
nn
import
paddle.nn
as
nn
import
numpy
as
np
import
cv2
from
.rec_ctc_loss
import
CTCLoss
from
.rec_ctc_loss
import
CTCLoss
from
.basic_loss
import
DMLLoss
from
.basic_loss
import
DMLLoss
...
@@ -22,6 +24,7 @@ from .det_db_loss import DBLoss
...
@@ -22,6 +24,7 @@ from .det_db_loss import DBLoss
from
.det_basic_loss
import
BalanceLoss
,
MaskL1Loss
,
DiceLoss
from
.det_basic_loss
import
BalanceLoss
,
MaskL1Loss
,
DiceLoss
def
_sum_loss
(
loss_dict
):
def
_sum_loss
(
loss_dict
):
if
"loss"
in
loss_dict
.
keys
():
if
"loss"
in
loss_dict
.
keys
():
return
loss_dict
return
loss_dict
...
@@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss):
...
@@ -50,7 +53,7 @@ class DistillationDMLLoss(DMLLoss):
self
.
key
=
key
self
.
key
=
key
self
.
model_name_pairs
=
model_name_pairs
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
self
.
name
=
name
self
.
maps_name
=
self
.
maps_name
self
.
maps_name
=
maps_name
def
_check_maps_name
(
self
,
maps_name
):
def
_check_maps_name
(
self
,
maps_name
):
if
maps_name
is
None
:
if
maps_name
is
None
:
...
@@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss):
...
@@ -172,6 +175,7 @@ class DistillationDBLoss(DBLoss):
class
DistillationDilaDBLoss
(
DBLoss
):
class
DistillationDilaDBLoss
(
DBLoss
):
def
__init__
(
self
,
def
__init__
(
self
,
model_name_pairs
=
[],
model_name_pairs
=
[],
key
=
None
,
balance_loss
=
True
,
balance_loss
=
True
,
main_loss_type
=
'DiceLoss'
,
main_loss_type
=
'DiceLoss'
,
alpha
=
5
,
alpha
=
5
,
...
@@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss):
...
@@ -182,6 +186,7 @@ class DistillationDilaDBLoss(DBLoss):
super
().
__init__
()
super
().
__init__
()
self
.
model_name_pairs
=
model_name_pairs
self
.
model_name_pairs
=
model_name_pairs
self
.
name
=
name
self
.
name
=
name
self
.
key
=
key
def
forward
(
self
,
predicts
,
batch
):
def
forward
(
self
,
predicts
,
batch
):
loss_dict
=
dict
()
loss_dict
=
dict
()
...
@@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss):
...
@@ -219,7 +224,7 @@ class DistillationDilaDBLoss(DBLoss):
loss_dict
[
k
]
=
bce_loss
+
loss_binary_maps
loss_dict
[
k
]
=
bce_loss
+
loss_binary_maps
loss_dict
=
_sum_loss
(
loss_dict
)
loss_dict
=
_sum_loss
(
loss_dict
)
return
loss
return
loss
_dict
class
DistillationDistanceLoss
(
DistanceLoss
):
class
DistillationDistanceLoss
(
DistanceLoss
):
...
...
ppocr/modeling/architectures/distillation_model.py
浏览文件 @
185d1e1f
...
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
...
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.heads
import
build_head
from
ppocr.modeling.heads
import
build_head
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_pretrained_params
__all__
=
[
'DistillationModel'
]
__all__
=
[
'DistillationModel'
]
...
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
...
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained
=
model_config
.
pop
(
"pretrained"
)
pretrained
=
model_config
.
pop
(
"pretrained"
)
model
=
BaseModel
(
model_config
)
model
=
BaseModel
(
model_config
)
if
pretrained
is
not
None
:
if
pretrained
is
not
None
:
init_model
(
model
,
path
=
pretrained
)
load_pretrained_params
(
model
,
pretrained
)
if
freeze_params
:
if
freeze_params
:
for
param
in
model
.
parameters
():
for
param
in
model
.
parameters
():
param
.
trainable
=
False
param
.
trainable
=
False
...
...
ppocr/postprocess/__init__.py
浏览文件 @
185d1e1f
...
@@ -21,7 +21,7 @@ import copy
...
@@ -21,7 +21,7 @@ import copy
__all__
=
[
'build_post_process'
]
__all__
=
[
'build_post_process'
]
from
.db_postprocess
import
DBPostProcess
from
.db_postprocess
import
DBPostProcess
,
DistillationDBPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.east_postprocess
import
EASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.sast_postprocess
import
SASTPostProcess
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
,
\
from
.rec_postprocess
import
CTCLabelDecode
,
AttnLabelDecode
,
SRNLabelDecode
,
DistillationCTCLabelDecode
,
\
...
@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
...
@@ -34,7 +34,7 @@ def build_post_process(config, global_config=None):
support_dict
=
[
support_dict
=
[
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'DBPostProcess'
,
'EASTPostProcess'
,
'SASTPostProcess'
,
'CTCLabelDecode'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'AttnLabelDecode'
,
'ClsPostProcess'
,
'SRNLabelDecode'
,
'PGPostProcess'
,
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
'DistillationCTCLabelDecode'
,
'TableLabelDecode'
,
'DistillationDBPostProcess'
]
]
config
=
copy
.
deepcopy
(
config
)
config
=
copy
.
deepcopy
(
config
)
...
...
ppocr/postprocess/db_postprocess.py
浏览文件 @
185d1e1f
...
@@ -187,3 +187,44 @@ class DBPostProcess(object):
...
@@ -187,3 +187,44 @@ class DBPostProcess(object):
boxes_batch
.
append
({
'points'
:
boxes
})
boxes_batch
.
append
({
'points'
:
boxes
})
return
boxes_batch
return
boxes_batch
class DistillationDBPostProcess(DBPostProcess):
    """Run the DB post-process on the outputs of several named sub-models.

    ``predicts`` is indexed by sub-model name; the parent ``DBPostProcess``
    is invoked once per configured name and the results are collected in a
    dict keyed by that same name.
    """

    def __init__(self,
                 model_name=["student"],
                 key=None,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 **kwargs):
        """Configure the parent post-process and the sub-models to decode.

        Args:
            model_name: A sub-model name, or a list of names, to look up in
                the prediction mapping. A non-list value is wrapped in a list.
            key: Optional key used to index each sub-model's prediction
                before it is handed to the parent; ``None`` passes the
                prediction through unchanged.
            thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
                score_mode: Forwarded positionally, in this order, to
                ``DBPostProcess.__init__``.
            **kwargs: Accepted and ignored.
        """
        super(DistillationDBPostProcess, self).__init__(
            thresh, box_thresh, max_candidates, unclip_ratio, use_dilation,
            score_mode)
        if not isinstance(model_name, list):
            model_name = [model_name]
        # Copy so the instance never aliases the shared default list (or the
        # caller's list); mutating self.model_name must not leak elsewhere.
        self.model_name = list(model_name)
        self.key = key

    def forward(self, predicts, shape_list):
        """Post-process the prediction of every configured sub-model.

        Args:
            predicts: Mapping indexed by sub-model name.
            shape_list: Passed unchanged to ``DBPostProcess.__call__`` as
                its ``shape_list`` keyword argument.

        Returns:
            dict: sub-model name -> result of the parent ``__call__``.
        """
        # NOTE(review): the parent is invoked through __call__, while this
        # entry point is named forward -- confirm callers invoke forward()
        # explicitly rather than calling the instance.
        results = {}
        for name in self.model_name:
            pred = predicts[name]
            if self.key is not None:
                pred = pred[self.key]
            # Fixed: the original passed the undefined name ``label`` here,
            # which raised NameError on the first iteration.
            results[name] = super().__call__(pred, shape_list=shape_list)
        return results
ppocr/utils/save_load.py
浏览文件 @
185d1e1f
...
@@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer):
...
@@ -116,6 +116,26 @@ def load_dygraph_params(config, model, logger, optimizer):
logger
.
info
(
f
"loaded pretrained_model successful from
{
pm
}
"
)
logger
.
info
(
f
"loaded pretrained_model successful from
{
pm
}
"
)
return
{}
return
{}
def load_pretrained_params(model, path):
    """Load pretrained weights from a ``.pdparams`` file into ``model``.

    Parameters are paired positionally: the i-th key of ``model.state_dict()``
    is matched with the i-th key of the loaded file, and the loaded value is
    used only when both shapes are equal. ``zip`` stops at the shorter of the
    two key sequences, so surplus entries on either side are ignored.

    Args:
        model: Object exposing ``state_dict()`` and ``set_state_dict()``.
        path: Weight file path, with or without the ``.pdparams`` suffix.
            ``None`` means "nothing to load".

    Returns:
        bool: ``False`` when ``path`` is ``None`` or the weight file does not
        exist; ``True`` after ``model.set_state_dict`` has been called.
    """
    if path is None:
        return False

    # Resolve the file that will actually be opened *before* testing for
    # existence. The original accepted a bare ``path`` that existed (e.g. a
    # directory) and then failed when loading ``path + '.pdparams'``.
    params_path = path if path.endswith('.pdparams') else path + '.pdparams'
    if not os.path.exists(params_path):
        print(f"The pretrained_model {path} does not exists!")
        return False

    params = paddle.load(params_path)
    state_dict = model.state_dict()
    new_state_dict = {}
    # Positional pairing: key names are not compared, only shapes.
    for (k1, v1), (k2, v2) in zip(state_dict.items(), params.items()):
        if list(v1.shape) == list(v2.shape):
            new_state_dict[k1] = v2
        else:
            print(
                f"The shape of model params {k1} {v1.shape} not matched with "
                f"loaded params {k2} {v2.shape} !")
    model.set_state_dict(new_state_dict)
    return True
def
save_model
(
model
,
def
save_model
(
model
,
optimizer
,
optimizer
,
...
...
tools/program.py
浏览文件 @
185d1e1f
...
@@ -186,7 +186,10 @@ def train(config,
...
@@ -186,7 +186,10 @@ def train(config,
model
.
train
()
model
.
train
()
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
model_type
=
config
[
'Architecture'
][
'model_type'
]
try
:
model_type
=
config
[
'Architecture'
][
'model_type'
]
except
:
model_type
=
None
if
'start_epoch'
in
best_model_dict
:
if
'start_epoch'
in
best_model_dict
:
start_epoch
=
best_model_dict
[
'start_epoch'
]
start_epoch
=
best_model_dict
[
'start_epoch'
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录