Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
86dd21f0
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
86dd21f0
编写于
9月 06, 2021
作者:
D
Double_V
提交者:
GitHub
9月 06, 2021
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3894 from LDOUBLEV/fix_distill
fix config about det distill
上级
332cb26a
fd628d56
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
189 addition
and
39 deletion
+189
-39
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
+16
-19
configs/det/ch_ppocr_v2.1/ch_det_mv3_db_v2.1_student.yml
configs/det/ch_ppocr_v2.1/ch_det_mv3_db_v2.1_student.yml
+132
-0
tools/eval.py
tools/eval.py
+3
-3
tools/infer_det.py
tools/infer_det.py
+38
-17
未找到文件。
configs/det/ch_ppocr_v2.1/ch_det_lite_train_cml_v2.1.yml
浏览文件 @
86dd21f0
...
@@ -8,7 +8,7 @@ Global:
...
@@ -8,7 +8,7 @@ Global:
# evaluation is run every 5000 iterations after the 4000th iteration
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
3000
,
2000
]
eval_batch_step
:
[
3000
,
2000
]
cal_metric_during_train
:
False
cal_metric_during_train
:
False
pretrained_model
:
./pretrain_models/
MobileNetV3_large_x0_5_pretrained
pretrained_model
:
./pretrain_models/
ch_ppocr_mobile_v2.1_det_distill_train/best_accuracy
checkpoints
:
checkpoints
:
save_inference_dir
:
save_inference_dir
:
use_visualdl
:
False
use_visualdl
:
False
...
@@ -19,30 +19,26 @@ Architecture:
...
@@ -19,30 +19,26 @@ Architecture:
name
:
DistillationModel
name
:
DistillationModel
algorithm
:
Distillation
algorithm
:
Distillation
Models
:
Models
:
Student
:
Teacher
:
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params
:
true
freeze_params
:
false
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
algorithm
:
DB
algorithm
:
DB
Transform
:
Backbone
:
Backbone
:
name
:
MobileNetV3
name
:
ResNet
scale
:
0.5
layers
:
18
model_name
:
large
disable_se
:
True
Neck
:
Neck
:
name
:
DBFPN
name
:
DBFPN
out_channels
:
9
6
out_channels
:
25
6
Head
:
Head
:
name
:
DBHead
name
:
DBHead
k
:
50
k
:
50
Student2
:
Student
:
pretrained
:
./pretrain_models/MobileNetV3_large_x0_5_pretrained
freeze_params
:
false
freeze_params
:
false
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
algorithm
:
DB
algorithm
:
DB
Transform
:
Backbone
:
Backbone
:
name
:
MobileNetV3
name
:
MobileNetV3
scale
:
0.5
scale
:
0.5
...
@@ -54,23 +50,24 @@ Architecture:
...
@@ -54,23 +50,24 @@ Architecture:
Head
:
Head
:
name
:
DBHead
name
:
DBHead
k
:
50
k
:
50
Teacher
:
Student2
:
pretrained
:
./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy
freeze_params
:
false
freeze_params
:
true
return_all_feats
:
false
return_all_feats
:
false
model_type
:
det
model_type
:
det
algorithm
:
DB
algorithm
:
DB
Transform
:
Transform
:
Backbone
:
Backbone
:
name
:
ResNet
name
:
MobileNetV3
layers
:
18
scale
:
0.5
model_name
:
large
disable_se
:
True
Neck
:
Neck
:
name
:
DBFPN
name
:
DBFPN
out_channels
:
25
6
out_channels
:
9
6
Head
:
Head
:
name
:
DBHead
name
:
DBHead
k
:
50
k
:
50
Loss
:
Loss
:
name
:
CombinedLoss
name
:
CombinedLoss
loss_config_list
:
loss_config_list
:
...
...
configs/det/ch_ppocr_v2.1/ch_det_mv3_db_v2.1_student.yml
0 → 100644
浏览文件 @
86dd21f0
Global
:
use_gpu
:
true
epoch_num
:
1200
log_smooth_window
:
20
print_batch_step
:
10
save_model_dir
:
./output/ch_db_mv3/
save_epoch_step
:
1200
# evaluation is run every 5000 iterations after the 4000th iteration
eval_batch_step
:
[
0
,
400
]
cal_metric_during_train
:
False
pretrained_model
:
./pretrain_models/student.pdparams
checkpoints
:
save_inference_dir
:
use_visualdl
:
False
infer_img
:
doc/imgs_en/img_10.jpg
save_res_path
:
./output/det_db/predicts_db.txt
Architecture
:
model_type
:
det
algorithm
:
DB
Transform
:
Backbone
:
name
:
MobileNetV3
scale
:
0.5
model_name
:
large
disable_se
:
True
Neck
:
name
:
DBFPN
out_channels
:
96
Head
:
name
:
DBHead
k
:
50
Loss
:
name
:
DBLoss
balance_loss
:
true
main_loss_type
:
DiceLoss
alpha
:
5
beta
:
10
ohem_ratio
:
3
Optimizer
:
name
:
Adam
beta1
:
0.9
beta2
:
0.999
lr
:
name
:
Cosine
learning_rate
:
0.001
warmup_epoch
:
2
regularizer
:
name
:
'
L2'
factor
:
0
PostProcess
:
name
:
DBPostProcess
thresh
:
0.3
box_thresh
:
0.6
max_candidates
:
1000
unclip_ratio
:
1.5
Metric
:
name
:
DetMetric
main_indicator
:
hmean
Train
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/train_icdar2015_label.txt
ratio_list
:
[
1.0
]
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
DetLabelEncode
:
# Class handling label
-
IaaAugment
:
augmenter_args
:
-
{
'
type'
:
Fliplr
,
'
args'
:
{
'
p'
:
0.5
}
}
-
{
'
type'
:
Affine
,
'
args'
:
{
'
rotate'
:
[
-10
,
10
]
}
}
-
{
'
type'
:
Resize
,
'
args'
:
{
'
size'
:
[
0.5
,
3
]
}
}
-
EastRandomCropData
:
size
:
[
960
,
960
]
max_tries
:
50
keep_ratio
:
true
-
MakeBorderMap
:
shrink_ratio
:
0.4
thresh_min
:
0.3
thresh_max
:
0.7
-
MakeShrinkMap
:
shrink_ratio
:
0.4
min_text_size
:
8
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
threshold_map'
,
'
threshold_mask'
,
'
shrink_map'
,
'
shrink_mask'
]
# the order of the dataloader list
loader
:
shuffle
:
True
drop_last
:
False
batch_size_per_card
:
8
num_workers
:
4
Eval
:
dataset
:
name
:
SimpleDataSet
data_dir
:
./train_data/icdar2015/text_localization/
label_file_list
:
-
./train_data/icdar2015/text_localization/test_icdar2015_label.txt
transforms
:
-
DecodeImage
:
# load image
img_mode
:
BGR
channel_first
:
False
-
DetLabelEncode
:
# Class handling label
-
DetResizeForTest
:
# image_shape: [736, 1280]
-
NormalizeImage
:
scale
:
1./255.
mean
:
[
0.485
,
0.456
,
0.406
]
std
:
[
0.229
,
0.224
,
0.225
]
order
:
'
hwc'
-
ToCHWImage
:
-
KeepKeys
:
keep_keys
:
[
'
image'
,
'
shape'
,
'
polys'
,
'
ignore_tags'
]
loader
:
shuffle
:
False
drop_last
:
False
batch_size_per_card
:
1
# must be 1
num_workers
:
2
tools/eval.py
浏览文件 @
86dd21f0
...
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
...
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
,
load_
pretrained
_params
from
ppocr.utils.save_load
import
init_model
,
load_
dygraph
_params
from
ppocr.utils.utility
import
print_dict
from
ppocr.utils.utility
import
print_dict
import
tools.program
as
program
import
tools.program
as
program
...
@@ -60,7 +60,7 @@ def main():
...
@@ -60,7 +60,7 @@ def main():
else
:
else
:
model_type
=
None
model_type
=
None
best_model_dict
=
init_model
(
config
,
model
)
best_model_dict
=
load_dygraph_params
(
config
,
model
,
logger
,
None
)
if
len
(
best_model_dict
):
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
for
k
,
v
in
best_model_dict
.
items
():
...
@@ -71,7 +71,7 @@ def main():
...
@@ -71,7 +71,7 @@ def main():
# start eval
# start eval
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
,
model_type
,
use_srn
)
eval_class
,
model_type
,
use_srn
)
logger
.
info
(
'metric eval ***************'
)
logger
.
info
(
'metric eval ***************'
)
for
k
,
v
in
metric
.
items
():
for
k
,
v
in
metric
.
items
():
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
logger
.
info
(
'{}:{}'
.
format
(
k
,
v
))
...
...
tools/infer_det.py
浏览文件 @
86dd21f0
...
@@ -34,23 +34,21 @@ import paddle
...
@@ -34,23 +34,21 @@ import paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
tools.program
as
program
def
draw_det_res
(
dt_boxes
,
config
,
img
,
img_name
):
def
draw_det_res
(
dt_boxes
,
config
,
img
,
img_name
,
save_path
):
if
len
(
dt_boxes
)
>
0
:
if
len
(
dt_boxes
)
>
0
:
import
cv2
import
cv2
src_im
=
img
src_im
=
img
for
box
in
dt_boxes
:
for
box
in
dt_boxes
:
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
box
=
box
.
astype
(
np
.
int32
).
reshape
((
-
1
,
1
,
2
))
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
cv2
.
polylines
(
src_im
,
[
box
],
True
,
color
=
(
255
,
255
,
0
),
thickness
=
2
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
if
not
os
.
path
.
exists
(
save_path
):
'save_res_path'
])
+
"/det_results/"
os
.
makedirs
(
save_path
)
if
not
os
.
path
.
exists
(
save_det_path
):
save_path
=
os
.
path
.
join
(
save_path
,
os
.
path
.
basename
(
img_name
))
os
.
makedirs
(
save_det_path
)
save_path
=
os
.
path
.
join
(
save_det_path
,
os
.
path
.
basename
(
img_name
))
cv2
.
imwrite
(
save_path
,
src_im
)
cv2
.
imwrite
(
save_path
,
src_im
)
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
logger
.
info
(
"The detected Image saved in {}"
.
format
(
save_path
))
...
@@ -61,8 +59,7 @@ def main():
...
@@ -61,8 +59,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
)
_
=
load_dygraph_params
(
config
,
model
,
logger
,
None
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
...
@@ -96,17 +93,41 @@ def main():
...
@@ -96,17 +93,41 @@ def main():
images
=
paddle
.
to_tensor
(
images
)
images
=
paddle
.
to_tensor
(
images
)
preds
=
model
(
images
)
preds
=
model
(
images
)
post_result
=
post_process_class
(
preds
,
shape_list
)
post_result
=
post_process_class
(
preds
,
shape_list
)
boxes
=
post_result
[
0
][
'points'
]
# write result
src_img
=
cv2
.
imread
(
file
)
dt_boxes_json
=
[]
dt_boxes_json
=
[]
for
box
in
boxes
:
# parser boxes if post_result is dict
tmp_json
=
{
"transcription"
:
""
}
if
isinstance
(
post_result
,
dict
):
tmp_json
[
'points'
]
=
box
.
tolist
()
det_box_json
=
{}
dt_boxes_json
.
append
(
tmp_json
)
for
k
in
post_result
.
keys
():
boxes
=
post_result
[
k
][
0
][
'points'
]
dt_boxes_list
=
[]
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_list
.
append
(
tmp_json
)
det_box_json
[
k
]
=
dt_boxes_list
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results_{}/"
.
format
(
k
)
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
else
:
boxes
=
post_result
[
0
][
'points'
]
dt_boxes_json
=
[]
# write result
for
box
in
boxes
:
tmp_json
=
{
"transcription"
:
""
}
tmp_json
[
'points'
]
=
box
.
tolist
()
dt_boxes_json
.
append
(
tmp_json
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
otstr
=
file
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
otstr
=
file
+
"
\t
"
+
json
.
dumps
(
dt_boxes_json
)
+
"
\n
"
fout
.
write
(
otstr
.
encode
())
fout
.
write
(
otstr
.
encode
())
src_img
=
cv2
.
imread
(
file
)
draw_det_res
(
boxes
,
config
,
src_img
,
file
)
save_det_path
=
os
.
path
.
dirname
(
config
[
'Global'
][
'save_res_path'
])
+
"/det_results/"
draw_det_res
(
boxes
,
config
,
src_img
,
file
,
save_det_path
)
logger
.
info
(
"success!"
)
logger
.
info
(
"success!"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录