Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
48d85379
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
48d85379
编写于
6月 05, 2021
作者:
littletomatodonkey
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rm load_dyg_pretrain
上级
bd1820b7
变更
10
显示空白变更内容
内联
并排
Showing
10 changed file
with
27 addition
and
29 deletion
+27
-29
configs/rec/ch_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
...h_ppocr_v2.1/rec_chinese_lite_train_distillation_v2.1.yml
+8
-8
ppocr/modeling/architectures/distillation_model.py
ppocr/modeling/architectures/distillation_model.py
+2
-2
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+10
-12
tools/eval.py
tools/eval.py
+1
-1
tools/export_model.py
tools/export_model.py
+1
-1
tools/infer_cls.py
tools/infer_cls.py
+1
-1
tools/infer_det.py
tools/infer_det.py
+1
-1
tools/infer_e2e.py
tools/infer_e2e.py
+1
-1
tools/infer_rec.py
tools/infer_rec.py
+1
-1
tools/train.py
tools/train.py
+1
-1
未找到文件。
configs/rec/ch_ppocr_v2.
0
/rec_chinese_lite_train_distillation_v2.1.yml
→
configs/rec/ch_ppocr_v2.
1
/rec_chinese_lite_train_distillation_v2.1.yml
浏览文件 @
48d85379
...
@@ -8,9 +8,9 @@ Global:
...
@@ -8,9 +8,9 @@ Global:
save_epoch_step
:
3
save_epoch_step
:
3
eval_batch_step
:
[
0
,
2000
]
eval_batch_step
:
[
0
,
2000
]
cal_metric_during_train
:
true
cal_metric_during_train
:
true
pretrained_model
:
null
pretrained_model
:
checkpoints
:
null
checkpoints
:
save_inference_dir
:
null
save_inference_dir
:
use_visualdl
:
false
use_visualdl
:
false
infer_img
:
doc/imgs_words/ch/word_1.jpg
infer_img
:
doc/imgs_words/ch/word_1.jpg
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
character_dict_path
:
ppocr/utils/ppocr_keys_v1.txt
...
@@ -38,7 +38,7 @@ Architecture:
...
@@ -38,7 +38,7 @@ Architecture:
algorithm
:
Distillation
algorithm
:
Distillation
Models
:
Models
:
Student
:
Student
:
pretrained
:
null
pretrained
:
freeze_params
:
false
freeze_params
:
false
return_all_feats
:
true
return_all_feats
:
true
model_type
:
rec
model_type
:
rec
...
@@ -57,7 +57,7 @@ Architecture:
...
@@ -57,7 +57,7 @@ Architecture:
name
:
CTCHead
name
:
CTCHead
fc_decay
:
0.00001
fc_decay
:
0.00001
Teacher
:
Teacher
:
pretrained
:
null
pretrained
:
freeze_params
:
false
freeze_params
:
false
return_all_feats
:
true
return_all_feats
:
true
model_type
:
rec
model_type
:
rec
...
@@ -118,8 +118,8 @@ Train:
...
@@ -118,8 +118,8 @@ Train:
-
DecodeImage
:
-
DecodeImage
:
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
false
channel_first
:
false
-
RecAug
:
null
-
RecAug
:
-
CTCLabelEncode
:
null
-
CTCLabelEncode
:
-
RecResizeImg
:
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
image_shape
:
[
3
,
32
,
320
]
-
KeepKeys
:
-
KeepKeys
:
...
@@ -143,7 +143,7 @@ Eval:
...
@@ -143,7 +143,7 @@ Eval:
-
DecodeImage
:
-
DecodeImage
:
img_mode
:
BGR
img_mode
:
BGR
channel_first
:
false
channel_first
:
false
-
CTCLabelEncode
:
null
-
CTCLabelEncode
:
-
RecResizeImg
:
-
RecResizeImg
:
image_shape
:
[
3
,
32
,
320
]
image_shape
:
[
3
,
32
,
320
]
-
KeepKeys
:
-
KeepKeys
:
...
...
ppocr/modeling/architectures/distillation_model.py
浏览文件 @
48d85379
...
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
...
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.heads
import
build_head
from
ppocr.modeling.heads
import
build_head
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
from
ppocr.utils.save_load
import
load_dygraph_pretrain
from
ppocr.utils.save_load
import
init_model
__all__
=
[
'DistillationModel'
]
__all__
=
[
'DistillationModel'
]
...
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
...
@@ -46,7 +46,7 @@ class DistillationModel(nn.Layer):
pretrained
=
model_config
.
pop
(
"pretrained"
)
pretrained
=
model_config
.
pop
(
"pretrained"
)
model
=
BaseModel
(
model_config
)
model
=
BaseModel
(
model_config
)
if
pretrained
is
not
None
:
if
pretrained
is
not
None
:
load_dygraph_pretrain
(
model
,
path
=
pretrained
)
init_model
(
model
,
path
=
pretrained
)
if
freeze_params
:
if
freeze_params
:
for
param
in
model
.
parameters
():
for
param
in
model
.
parameters
():
param
.
trainable
=
False
param
.
trainable
=
False
...
...
ppocr/utils/save_load.py
浏览文件 @
48d85379
...
@@ -23,6 +23,8 @@ import six
...
@@ -23,6 +23,8 @@ import six
import
paddle
import
paddle
from
ppocr.utils.logging
import
get_logger
__all__
=
[
'init_model'
,
'save_model'
,
'load_dygraph_pretrain'
]
__all__
=
[
'init_model'
,
'save_model'
,
'load_dygraph_pretrain'
]
...
@@ -42,19 +44,11 @@ def _mkdir_if_not_exist(path, logger):
...
@@ -42,19 +44,11 @@ def _mkdir_if_not_exist(path, logger):
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
load_dygraph_pretrain
(
model
,
logger
=
None
,
path
=
None
):
def
init_model
(
config
,
model
,
optimizer
=
None
,
lr_scheduler
=
None
):
if
not
(
os
.
path
.
isdir
(
path
)
or
os
.
path
.
exists
(
path
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
path
))
param_state_dict
=
paddle
.
load
(
path
+
'.pdparams'
)
model
.
set_state_dict
(
param_state_dict
)
return
def
init_model
(
config
,
model
,
logger
,
optimizer
=
None
,
lr_scheduler
=
None
):
"""
"""
load model from checkpoint or pretrained_model
load model from checkpoint or pretrained_model
"""
"""
logger
=
get_logger
()
global_config
=
config
[
'Global'
]
global_config
=
config
[
'Global'
]
checkpoints
=
global_config
.
get
(
'checkpoints'
)
checkpoints
=
global_config
.
get
(
'checkpoints'
)
pretrained_model
=
global_config
.
get
(
'pretrained_model'
)
pretrained_model
=
global_config
.
get
(
'pretrained_model'
)
...
@@ -77,13 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
...
@@ -77,13 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict
=
states_dict
.
get
(
'best_model_dict'
,
{})
best_model_dict
=
states_dict
.
get
(
'best_model_dict'
,
{})
if
'epoch'
in
states_dict
:
if
'epoch'
in
states_dict
:
best_model_dict
[
'start_epoch'
]
=
states_dict
[
'epoch'
]
+
1
best_model_dict
[
'start_epoch'
]
=
states_dict
[
'epoch'
]
+
1
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
elif
pretrained_model
:
elif
pretrained_model
:
if
not
isinstance
(
pretrained_model
,
list
):
if
not
isinstance
(
pretrained_model
,
list
):
pretrained_model
=
[
pretrained_model
]
pretrained_model
=
[
pretrained_model
]
for
pretrained
in
pretrained_model
:
for
pretrained
in
pretrained_model
:
load_dygraph_pretrain
(
model
,
logger
,
path
=
pretrained
)
if
not
(
os
.
path
.
isdir
(
pretrained
)
or
os
.
path
.
exists
(
pretrained
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
pretrained
))
param_state_dict
=
paddle
.
load
(
pretrained
+
'.pdparams'
)
model
.
set_state_dict
(
param_state_dict
)
logger
.
info
(
"load pretrained model from {}"
.
format
(
logger
.
info
(
"load pretrained model from {}"
.
format
(
pretrained_model
))
pretrained_model
))
else
:
else
:
...
...
tools/eval.py
浏览文件 @
48d85379
...
@@ -49,7 +49,7 @@ def main():
...
@@ -49,7 +49,7 @@ def main():
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
use_srn
=
config
[
'Architecture'
][
'algorithm'
]
==
"SRN"
best_model_dict
=
init_model
(
config
,
model
,
logger
)
best_model_dict
=
init_model
(
config
,
model
)
if
len
(
best_model_dict
):
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
for
k
,
v
in
best_model_dict
.
items
():
...
...
tools/export_model.py
浏览文件 @
48d85379
...
@@ -95,7 +95,7 @@ def main():
...
@@ -95,7 +95,7 @@ def main():
else
:
# base rec model
else
:
# base rec model
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
model
=
build_model
(
config
[
"Architecture"
])
model
=
build_model
(
config
[
"Architecture"
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
model
.
eval
()
model
.
eval
()
save_path
=
config
[
"Global"
][
"save_inference_dir"
]
save_path
=
config
[
"Global"
][
"save_inference_dir"
]
...
...
tools/infer_cls.py
浏览文件 @
48d85379
...
@@ -47,7 +47,7 @@ def main():
...
@@ -47,7 +47,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
...
tools/infer_det.py
浏览文件 @
48d85379
...
@@ -61,7 +61,7 @@ def main():
...
@@ -61,7 +61,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
...
...
tools/infer_e2e.py
浏览文件 @
48d85379
...
@@ -68,7 +68,7 @@ def main():
...
@@ -68,7 +68,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
...
...
tools/infer_rec.py
浏览文件 @
48d85379
...
@@ -58,7 +58,7 @@ def main():
...
@@ -58,7 +58,7 @@ def main():
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
init_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
...
tools/train.py
浏览文件 @
48d85379
...
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
pre_best_model_dict
=
init_model
(
config
,
model
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
if
valid_dataloader
is
not
None
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录