PaddleOCR, commit 2062b509 (unverified signature)
Authored Jul 12, 2021 by zhoujun; committed via GitHub on Jul 12, 2021.

Merge pull request #3273 from LDOUBLEV/distill

Add det distill

Parents: 02b75a50, 76bb40fc

Showing 16 changed files with 833 additions and 39 deletions (+833 −39):
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml       +202   −0
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml   +174   −0
configs/det/ch_ppocr_v2.1/ch_det_lite_train_dml_v2.1.yml       +176   −0
ppocr/losses/basic_loss.py                                     +31    −6
ppocr/losses/combined_loss.py                                  +8     −7
ppocr/losses/distillation_loss.py                              +171   −9
ppocr/metrics/det_metric.py                                    +1     −0
ppocr/metrics/distillation_metric.py                           +4     −7
ppocr/modeling/architectures/base_model.py                     +4     −1
ppocr/modeling/architectures/distillation_model.py             +2     −2
ppocr/postprocess/__init__.py                                  +3     −2
ppocr/postprocess/db_postprocess.py                            +26    −0
ppocr/utils/save_load.py                                       +21    −0
tools/eval.py                                                  +6     −3
tools/program.py                                               +4     −1
tools/train.py                                                 +0     −1
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml  (new file, mode 100644)

Global:
  use_gpu: true
  epoch_num: 1200
  log_smooth_window: 20
  print_batch_step: 2
  save_model_dir: ./output/ch_db_mv3/
  save_epoch_step: 1200
  # evaluation is run every 5000 iterations after the 4000th iteration
  eval_batch_step: [3000, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  name: DistillationModel
  algorithm: Distillation
  Models:
    Student:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Teacher:
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      # act: None
      model_name_pairs: ["Student", "Student2"]
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      # key: maps
      # name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine
    learning_rate: 0.001
    warmup_epoch: 2
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student", "Student2", "Teacher"]
  # key: maps
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DistillationMetric
  base_metric_name: DetMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage:  # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode:  # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask']  # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/test_icdar2015_label.txt
    transforms:
      - DecodeImage:  # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode:  # Class handling label
      - DetResizeForTest:
      #     image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1  # must be 1
    num_workers: 2
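For orientation (not part of the commit): a minimal sketch of how this CML config is consumed. PaddleOCR's build_model turns the Architecture section into a DistillationModel that runs the two MobileNetV3 students and the frozen ResNet-18 teacher on the same input and returns one output dict per sub-model. The sketch assumes a PaddleOCR checkout on PYTHONPATH and nulls out the pretrained entries only so it runs without the downloaded weight files; actual training is launched the usual way with tools/train.py -c on this config.

# Minimal sketch (assumptions: PaddleOCR on PYTHONPATH, weight loading skipped).
import yaml
import paddle
from ppocr.modeling.architectures import build_model

with open("configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml") as f:
    config = yaml.safe_load(f)

arch = config["Architecture"]
for sub_cfg in arch["Models"].values():
    sub_cfg["pretrained"] = None  # skip loading pretrained weights in this sketch

model = build_model(arch)  # DistillationModel: Student, Student2 and the frozen Teacher
model.eval()

x = paddle.rand([1, 3, 640, 640])
outs = model(x)  # dict keyed by sub-model name; each DB head returns {"maps": ...}
for name, out in outs.items():
    print(name, out["maps"].shape)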
configs/det/ch_ppocr_v2.1/ch_det_lite_train_distill_v2.1.yml  (new file, mode 100644)

The Global, Optimizer, Metric, Train and Eval sections are identical to ch_det_lite_train_cml_v2.1.yml above. The sections specific to this plain teacher-to-student distillation config are:

Architecture:
  name: DistillationModel
  algorithm: Distillation
  Models:
    Student:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Teacher:
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
      freeze_params: true
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDilaDBLoss:
      weight: 1.0
      model_name_pairs:
      - ["Student", "Teacher"]
      key: maps
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      # key: maps
      name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student", "Student2"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5
configs/det/ch_ppocr_v2.1/ch_det_lite_train_dml_v2.1.yml  (new file, mode 100644)

Again, the Global, Optimizer, Metric, Train and Eval sections match ch_det_lite_train_cml_v2.1.yml. The DML-specific sections are:

Architecture:
  name: DistillationModel
  algorithm: Distillation
  Models:
    Student:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Student2:
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      freeze_params: false
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50

Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      model_name_pairs:
      - ["Student", "Student2"]
      maps_name: "thrink_maps"
      weight: 1.0
      act: "softmax"
      model_name_pairs: ["Student", "Student2"]
      key: maps
  - DistillationDBLoss:
      weight: 1.0
      model_name_list: ["Student", "Student2"]
      # key: maps
      name: DBLoss
      balance_loss: true
      main_loss_type: DiceLoss
      alpha: 5
      beta: 10
      ohem_ratio: 3

PostProcess:
  name: DistillationDBPostProcess
  model_name: ["Student", "Student2"]
  key: head_out
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5
ppocr/losses/basic_loss.py

Adds a KLJSLoss helper that computes a KL or JS divergence over probability maps, and wires DMLLoss to use it: inputs with two or more dimensions now go through the JS branch instead of F.kl_div. Updated hunks:

@@ -54,6 +54,27 @@ class CELoss(nn.Layer):
        return loss


class KLJSLoss(object):
    def __init__(self, mode='kl'):
        assert mode in ['kl', 'js', 'KL', 'JS'], "mode can only be one of ['kl', 'js', 'KL', 'JS']"
        self.mode = mode

    def __call__(self, p1, p2, reduction="mean"):
        loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))

        if self.mode.lower() == "js":
            loss += paddle.multiply(
                p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
            loss *= 0.5
        if reduction == "mean":
            loss = paddle.mean(loss, axis=[1, 2])
        elif reduction == "none" or reduction is None:
            return loss
        else:
            loss = paddle.sum(loss, axis=[1, 2])

        return loss


class DMLLoss(nn.Layer):
    """
    DMLLoss

@@ -69,17 +90,21 @@ class DMLLoss(nn.Layer):
            self.act = nn.Sigmoid()
        else:
            self.act = None

        self.jskl_loss = KLJSLoss(mode="js")

    def forward(self, out1, out2):
        if self.act is not None:
            out1 = self.act(out1)
            out2 = self.act(out2)
        if len(out1.shape) < 2:
            log_out1 = paddle.log(out1)
            log_out2 = paddle.log(out2)
            loss = (F.kl_div(
                log_out1, out2, reduction='batchmean') + F.kl_div(
                    log_out2, out1, reduction='batchmean')) / 2.0
        else:
            loss = self.jskl_loss(out1, out2)
        return loss
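A quick exercise of the new loss (not part of the commit), assuming a PaddleOCR checkout on PYTHONPATH: KLJSLoss takes two probability maps of shape [N, H, W] and reduces over the spatial axes, and DMLLoss now routes such map-shaped inputs through it.

# Hedged sketch: exercise KLJSLoss and the new DMLLoss branch on random maps.
import paddle
from ppocr.losses.basic_loss import DMLLoss, KLJSLoss

p1 = paddle.nn.functional.sigmoid(paddle.randn([2, 32, 32]))  # [N, H, W] probability maps
p2 = paddle.nn.functional.sigmoid(paddle.randn([2, 32, 32]))

js = KLJSLoss(mode="js")
print(js(p1, p2).shape)   # reduction="mean" averages over H and W -> shape [2]

dml = DMLLoss(act=None)   # rank-3 inputs take the new jskl_loss branch
print(dml(p1, p2).shape)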
ppocr/losses/combined_loss.py

Imports the new detection distillation losses and reworks how CombinedLoss aggregates sub-losses: each sub-loss's "loss" entry is scaled by its configured weight and summed into the total, while the other entries are kept unweighted for logging (previously every entry was weighted and the total was formed with paddle.add_n over all of them). Updated hunks:

@@ -17,7 +17,7 @@ import paddle.nn as nn
from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss


class CombinedLoss(nn.Layer):

@@ -44,15 +44,16 @@ class CombinedLoss(nn.Layer):
    def forward(self, input, batch, **kargs):
        loss_dict = {}
        loss_all = 0.
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                loss = {"loss_{}_{}".format(str(loss), idx): loss}
            weight = self.loss_weight[idx]
            for key in loss.keys():
                if key == "loss":
                    loss_all += loss[key] * weight
                else:
                    loss_dict["{}_{}".format(key, idx)] = loss[key]
        loss_dict["loss"] = loss_all
        return loss_dict
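To make the reworked aggregation concrete, a standalone illustration (plain dicts, not the class itself): only each sub-loss's "loss" entry is scaled by its configured weight and added to the total; every other entry is passed through unweighted for logging.

# Standalone illustration of the CombinedLoss aggregation rule (not the class itself).
import paddle

sub_losses = [  # what two configured sub-losses might return for one batch
    {"loss": paddle.to_tensor(2.0), "dila_dbloss_Student_Teacher": paddle.to_tensor(1.5)},
    {"loss": paddle.to_tensor(0.8)},
]
weights = [1.0, 1.0]  # the `weight` entries from loss_config_list

loss_dict, loss_all = {}, 0.
for idx, (loss, weight) in enumerate(zip(sub_losses, weights)):
    for key in loss:
        if key == "loss":
            loss_all += loss[key] * weight                   # weighted total
        else:
            loss_dict["{}_{}".format(key, idx)] = loss[key]  # logged as-is
loss_dict["loss"] = loss_all
print({k: float(v) for k, v in loss_dict.items()})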
ppocr/losses/distillation_loss.py

Extends DistillationDMLLoss with a maps_name option: when set, the DB output tensor is sliced into the named channels (thrink_maps / threshold_maps / binary_maps) and DML is computed per map. Also adds DistillationDBLoss, which applies the ordinary DBLoss to each model in model_name_list, and DistillationDilaDBLoss, which supervises each student's shrink and binary maps with a binarized, dilated version of the teacher's shrink map. Updated hunks:

@@ -14,23 +14,76 @@
import paddle
import paddle.nn as nn
import numpy as np
import cv2

from .rec_ctc_loss import CTCLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss


def _sum_loss(loss_dict):
    if "loss" in loss_dict.keys():
        return loss_dict
    else:
        loss_dict["loss"] = 0.
        for k, value in loss_dict.items():
            if k == "loss":
                continue
            else:
                loss_dict["loss"] += value
        return loss_dict


class DistillationDMLLoss(DMLLoss):
    """
    """

    def __init__(self,
                 model_name_pairs=[],
                 act=None,
                 key=None,
                 maps_name=None,
                 name="dml"):
        super().__init__(act=act)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = self._check_model_name_pairs(model_name_pairs)
        self.name = name
        self.maps_name = self._check_maps_name(maps_name)

    def _check_model_name_pairs(self, model_name_pairs):
        if not isinstance(model_name_pairs, list):
            return []
        elif isinstance(model_name_pairs[0], list) and isinstance(
                model_name_pairs[0][0], str):
            return model_name_pairs
        else:
            return [model_name_pairs]

    def _check_maps_name(self, maps_name):
        if maps_name is None:
            return None
        elif type(maps_name) == str:
            return [maps_name]
        elif type(maps_name) == list:
            return [maps_name]
        else:
            return None

    def _slice_out(self, outs):
        new_outs = {}
        for k in self.maps_name:
            if k == "thrink_maps":
                new_outs[k] = outs[:, 0, :, :]
            elif k == "threshold_maps":
                new_outs[k] = outs[:, 1, :, :]
            elif k == "binary_maps":
                new_outs[k] = outs[:, 2, :, :]
            else:
                continue
        return new_outs

    def forward(self, predicts, batch):
        loss_dict = dict()

@@ -40,13 +93,30 @@ class DistillationDMLLoss(DMLLoss):
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]

            if self.maps_name is None:
                loss = super().forward(out1, out2)
                if isinstance(loss, dict):
                    for key in loss:
                        loss_dict["{}_{}_{}_{}".format(key, pair[0], pair[1],
                                                       idx)] = loss[key]
                else:
                    loss_dict["{}_{}".format(self.name, idx)] = loss
            else:
                outs1 = self._slice_out(out1)
                outs2 = self._slice_out(out2)
                for _c, k in enumerate(outs1.keys()):
                    loss = super().forward(outs1[k], outs2[k])
                    if isinstance(loss, dict):
                        for key in loss:
                            loss_dict["{}_{}_{}_{}_{}".format(
                                key, pair[0], pair[1], map_name, idx)] = loss[key]
                    else:
                        loss_dict["{}_{}_{}".format(self.name, self.maps_name[
                            _c], idx)] = loss

        loss_dict = _sum_loss(loss_dict)

        return loss_dict

@@ -73,6 +143,98 @@ class DistillationCTCLoss(CTCLoss):
        return loss_dict


class DistillationDBLoss(DBLoss):
    def __init__(self,
                 model_name_list=[],
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 name="db",
                 **kwargs):
        super().__init__()
        self.model_name_list = model_name_list
        self.name = name
        self.key = None

    def forward(self, predicts, batch):
        loss_dict = {}
        for idx, model_name in enumerate(self.model_name_list):
            out = predicts[model_name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)

            if isinstance(loss, dict):
                for key in loss.keys():
                    if key == "loss":
                        continue
                    name = "{}_{}_{}".format(self.name, model_name, key)
                    loss_dict[name] = loss[key]
            else:
                loss_dict["{}_{}".format(self.name, model_name)] = loss

        loss_dict = _sum_loss(loss_dict)
        return loss_dict


class DistillationDilaDBLoss(DBLoss):
    def __init__(self,
                 model_name_pairs=[],
                 key=None,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 name="dila_dbloss"):
        super().__init__()
        self.model_name_pairs = model_name_pairs
        self.name = name
        self.key = key

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            stu_outs = predicts[pair[0]]
            tch_outs = predicts[pair[1]]
            if self.key is not None:
                stu_preds = stu_outs[self.key]
                tch_preds = tch_outs[self.key]

            stu_shrink_maps = stu_preds[:, 0, :, :]
            stu_binary_maps = stu_preds[:, 2, :, :]

            # dilation to teacher prediction
            dilation_w = np.array([[1, 1], [1, 1]])
            th_shrink_maps = tch_preds[:, 0, :, :]
            th_shrink_maps = th_shrink_maps.numpy() > 0.3  # thresh = 0.3
            dilate_maps = np.zeros_like(th_shrink_maps).astype(np.float32)
            for i in range(th_shrink_maps.shape[0]):
                dilate_maps[i] = cv2.dilate(
                    th_shrink_maps[i, :, :].astype(np.uint8), dilation_w)
            th_shrink_maps = paddle.to_tensor(dilate_maps)

            label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = batch[
                1:]

            # calculate the shrink map loss
            bce_loss = self.alpha * self.bce_loss(
                stu_shrink_maps, th_shrink_maps, label_shrink_mask)
            loss_binary_maps = self.dice_loss(stu_binary_maps, th_shrink_maps,
                                              label_shrink_mask)

            # k = f"{self.name}_{pair[0]}_{pair[1]}"
            k = "{}_{}_{}".format(self.name, pair[0], pair[1])
            loss_dict[k] = bce_loss + loss_binary_maps

        loss_dict = _sum_loss(loss_dict)
        return loss_dict


class DistillationDistanceLoss(DistanceLoss):
    """
    """
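The teacher-map dilation used by DistillationDilaDBLoss, isolated as a standalone sketch (not part of the commit): the teacher's shrink map is binarized at 0.3 and dilated with a 2x2 kernel before serving as the target for the student's shrink and binary maps.

# Standalone sketch of the teacher shrink-map dilation in DistillationDilaDBLoss.
import cv2
import numpy as np
import paddle

tch_shrink_maps = paddle.rand([2, 160, 160])   # teacher shrink maps, [N, H, W]
binary = tch_shrink_maps.numpy() > 0.3         # thresh = 0.3, as in the loss

dilation_w = np.array([[1, 1], [1, 1]])        # 2x2 dilation kernel
dilate_maps = np.zeros_like(binary).astype(np.float32)
for i in range(binary.shape[0]):
    dilate_maps[i] = cv2.dilate(binary[i].astype(np.uint8), dilation_w)

target = paddle.to_tensor(dilate_maps)         # replaces the GT shrink map as supervision
print(target.shape)                            # same [N, H, W], text regions slightly grown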
ppocr/metrics/det_metric.py

One line added in DetMetric; the diff context around DetMetric.__call__ / get_metric is shown below.

@@ -55,6 +55,7 @@ class DetMetric(object):
        result = self.evaluator.evaluate_image(gt_info_list, det_info_list)
        self.results.append(result)

    def get_metric(self):
        """
        return metrics {
ppocr/metrics/distillation_metric.py

DistillationMetric now takes base_metric_name and main_indicator from the config (previously hard-coded to RecMetric / acc) and, on each call, feeds every sub-model's predictions into its own accumulated metric instead of returning a merged result dict; the aggregated numbers are read out later through get_metric. Updated hunks:

@@ -24,8 +24,8 @@ from .cls_metric import ClsMetric
class DistillationMetric(object):
    def __init__(self,
                 key=None,
                 base_metric_name=None,
                 main_indicator=None,
                 **kwargs):
        self.main_indicator = main_indicator
        self.key = key

@@ -42,16 +42,13 @@ class DistillationMetric(object):
                main_indicator=self.main_indicator,
                **self.kwargs)
            self.metrics[key].reset()

    def __call__(self, preds, batch, **kwargs):
        assert isinstance(preds, dict)
        if self.metrics is None:
            self._init_metrcis(preds)
        for key in preds:
            self.metrics[key].__call__(preds[key], batch, **kwargs)

    def get_metric(self):
        """
ppocr/modeling/architectures/base_model.py

When the head returns a dict (as the DB head does), its entries are now merged into the feature dict instead of being stored wholesale under head_out. Updated hunk:

@@ -79,7 +79,10 @@ class BaseModel(nn.Layer):
            x = self.neck(x)
            y["neck_out"] = x
        x = self.head(x, targets=data)
        if isinstance(x, dict):
            y.update(x)
        else:
            y["head_out"] = x
        if self.return_all_feats:
            return y
        else:
ppocr/modeling/architectures/distillation_model.py

Sub-model weights are now loaded with the new load_pretrained_params helper instead of init_model. Updated hunks:

@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import init_model, load_pretrained_params

__all__ = ['DistillationModel']

@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
            pretrained = model_config.pop("pretrained")
            model = BaseModel(model_config)
            if pretrained is not None:
                model = load_pretrained_params(model, pretrained)
            if freeze_params:
                for param in model.parameters():
                    param.trainable = False
ppocr/postprocess/__init__.py

Exports and registers DistillationDBPostProcess. Updated hunks:

@@ -21,7 +21,7 @@ import copy

__all__ = ['build_post_process']

from .db_postprocess import DBPostProcess, DistillationDBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, DistillationCTCLabelDecode, \

@@ -34,7 +34,8 @@ def build_post_process(config, global_config=None):
    support_dict = [
        'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
        'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode', 'PGPostProcess',
        'DistillationCTCLabelDecode', 'TableLabelDecode',
        'DistillationDBPostProcess'
    ]

    config = copy.deepcopy(config)
ppocr/postprocess/db_postprocess.py

Adds DistillationDBPostProcess, a thin wrapper that applies one shared DBPostProcess to the output of every sub-model listed in model_name. New code:

@@ -187,3 +187,29 @@ class DBPostProcess(object):
            boxes_batch.append({'points': boxes})
        return boxes_batch


class DistillationDBPostProcess(object):
    def __init__(self,
                 model_name=["student"],
                 key=None,
                 thresh=0.3,
                 box_thresh=0.6,
                 max_candidates=1000,
                 unclip_ratio=1.5,
                 use_dilation=False,
                 score_mode="fast",
                 **kwargs):
        self.model_name = model_name
        self.key = key
        self.post_process = DBPostProcess(
            thresh=thresh,
            box_thresh=box_thresh,
            max_candidates=max_candidates,
            unclip_ratio=unclip_ratio,
            use_dilation=use_dilation,
            score_mode=score_mode)

    def __call__(self, predicts, shape_list):
        results = {}
        for k in self.model_name:
            results[k] = self.post_process(predicts[k], shape_list=shape_list)
        return results
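A rough usage sketch (not part of the commit), assuming DBPostProcess accepts a {"maps": tensor} dict plus a shape_list of [src_h, src_w, ratio_h, ratio_w] rows as elsewhere in PaddleOCR: DistillationDBPostProcess simply fans the shared DBPostProcess out over the sub-models listed in model_name, returning one box list per sub-model.

# Hedged sketch: run the distillation post-process over dummy DB output maps.
# Assumes a PaddleOCR checkout on PYTHONPATH (DBPostProcess needs pyclipper/shapely).
import paddle
from ppocr.postprocess.db_postprocess import DistillationDBPostProcess

post = DistillationDBPostProcess(
    model_name=["Student", "Teacher"], thresh=0.3, box_thresh=0.6)

predicts = {  # one 1-channel probability map per sub-model
    "Student": {"maps": paddle.rand([1, 1, 64, 64])},
    "Teacher": {"maps": paddle.rand([1, 1, 64, 64])},
}
shape_list = [[64, 64, 1.0, 1.0]]  # [src_h, src_w, ratio_h, ratio_w] per image

results = post(predicts, shape_list)
print({k: len(v[0]["points"]) for k, v in results.items()})  # detected boxes per sub-model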
ppocr/utils/save_load.py

Adds load_pretrained_params, which loads a .pdparams file, pairs parameters with the current model by position, copies those whose shapes match and warns about the rest. New code:

@@ -116,6 +116,27 @@ def load_dygraph_params(config, model, logger, optimizer):
            logger.info(f"loaded pretrained_model successful from {pm}")
        return {}


def load_pretrained_params(model, path):
    if path is None:
        return False
    if not os.path.exists(path) and not os.path.exists(path + ".pdparams"):
        print(f"The pretrained_model {path} does not exists!")
        return False
    path = path if path.endswith('.pdparams') else path + '.pdparams'
    params = paddle.load(path)
    state_dict = model.state_dict()
    new_state_dict = {}
    for k1, k2 in zip(state_dict.keys(), params.keys()):
        if list(state_dict[k1].shape) == list(params[k2].shape):
            new_state_dict[k1] = params[k2]
        else:
            print(
                f"The shape of model params {k1} {state_dict[k1].shape} not matched with loaded params {k2} {params[k2].shape} !"
            )
    model.set_state_dict(new_state_dict)
    print(f"load pretrain successful from {path}")
    return model


def save_model(model,
               optimizer,
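A hedged sketch of the new loader (not part of the commit), assuming a PaddleOCR checkout on PYTHONPATH and that the weight file referenced by the CML config exists locally: load_pretrained_params pairs parameters by position and copies only those whose shapes match, so it tolerates renamed keys.

# Hedged sketch: load detector weights with the new shape-matching loader.
import yaml
from ppocr.modeling.architectures.base_model import BaseModel
from ppocr.utils.save_load import load_pretrained_params

with open("configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml") as f:
    config = yaml.safe_load(f)

student_cfg = dict(config["Architecture"]["Models"]["Student"])
for k in ("pretrained", "freeze_params"):
    student_cfg.pop(k, None)  # distillation-only keys, handled by DistillationModel

student = BaseModel(student_cfg)

# Returns the model on success, or False if no .pdparams file is found at the path.
student = load_pretrained_params(
    student, "./pretrain_models/MobileNetV3_large_x0_5_pretrained")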
tools/eval.py

model_type is now optional at the Architecture level (the distillation configs only define it per sub-model), falling back to None, and load_pretrained_params is imported alongside init_model. Updated hunks:

@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model, load_pretrained_params
from ppocr.utils.utility import print_dict
import tools.program as program

@@ -55,7 +55,10 @@ def main():
    model = build_model(config['Architecture'])
    use_srn = config['Architecture']['algorithm'] == "SRN"
    if "model_type" in config['Architecture'].keys():
        model_type = config['Architecture']['model_type']
    else:
        model_type = None

    best_model_dict = init_model(config, model)
    if len(best_model_dict):

@@ -68,7 +71,7 @@ def main():
    # start eval
    metric = program.eval(model, valid_dataloader, post_process_class,
                          eval_class, model_type, use_srn)
    logger.info('metric eval ***************')
    for k, v in metric.items():
        logger.info('{}:{}'.format(k, v))
tools/program.py

train() now wraps the Architecture.model_type lookup in try/except so configs without a top-level model_type (such as the distillation configs) fall back to None. Updated hunk:

@@ -186,7 +186,10 @@ def train(config,
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    try:
        model_type = config['Architecture']['model_type']
    except:
        model_type = None

    if 'start_epoch' in best_model_dict:
        start_epoch = best_model_dict['start_epoch']
tools/train.py

One line removed around the pretrained-model loading block in main(); the diff context is shown below.

@@ -98,7 +98,6 @@ def main(config, device, logger, vdl_writer):
    eval_class = build_metric(config['Metric'])
    # load pretrain model
    pre_best_model_dict = load_dygraph_params(config, model, logger, optimizer)

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(