Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleOCR
提交
ae4167dc
P
PaddleOCR
项目概览
PaddlePaddle
/
PaddleOCR
大约 1 年 前同步成功
通知
1528
Star
32962
Fork
6643
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
108
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
108
Issue
108
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ae4167dc
编写于
11月 12, 2021
作者:
Z
zhoujun
提交者:
GitHub
11月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merge init_model and load_dygraph_params to load_model (#4623)
* merge init_model and load_dygraph_params to load_model
上级
1417a3c2
变更
16
隐藏空白更改
内联
并排
Showing
16 changed file
with
48 addition
and
90 deletion
+48
-90
deploy/slim/prune/export_prune_model.py
deploy/slim/prune/export_prune_model.py
+2
-2
deploy/slim/prune/sensitivity_anal.py
deploy/slim/prune/sensitivity_anal.py
+2
-2
deploy/slim/quantization/export_model.py
deploy/slim/quantization/export_model.py
+2
-2
deploy/slim/quantization/quant.py
deploy/slim/quantization/quant.py
+2
-2
deploy/slim/quantization/quant_kl.py
deploy/slim/quantization/quant_kl.py
+1
-1
ppocr/modeling/architectures/distillation_model.py
ppocr/modeling/architectures/distillation_model.py
+1
-1
ppocr/utils/save_load.py
ppocr/utils/save_load.py
+18
-58
tools/eval.py
tools/eval.py
+2
-2
tools/export_center.py
tools/export_center.py
+2
-2
tools/export_model.py
tools/export_model.py
+2
-2
tools/infer_cls.py
tools/infer_cls.py
+2
-2
tools/infer_det.py
tools/infer_det.py
+2
-2
tools/infer_e2e.py
tools/infer_e2e.py
+2
-2
tools/infer_rec.py
tools/infer_rec.py
+3
-5
tools/infer_table.py
tools/infer_table.py
+3
-3
tools/train.py
tools/train.py
+2
-2
未找到文件。
deploy/slim/prune/export_prune_model.py
浏览文件 @
ae4167dc
...
@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
...
@@ -30,7 +30,7 @@ from ppocr.modeling.architectures import build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
import
tools.program
as
program
import
tools.program
as
program
...
@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -89,7 +89,7 @@ def main(config, device, logger, vdl_writer):
logger
.
info
(
f
"FLOPs after pruning:
{
flops
}
"
)
logger
.
info
(
f
"FLOPs after pruning:
{
flops
}
"
)
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
None
)
load_model
(
config
,
model
)
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
metric
=
program
.
eval
(
model
,
valid_dataloader
,
post_process_class
,
eval_class
)
eval_class
)
logger
.
info
(
f
"metric['hmean']:
{
metric
[
'hmean'
]
}
"
)
logger
.
info
(
f
"metric['hmean']:
{
metric
[
'hmean'
]
}
"
)
...
...
deploy/slim/prune/sensitivity_anal.py
浏览文件 @
ae4167dc
...
@@ -32,7 +32,7 @@ from ppocr.losses import build_loss
...
@@ -32,7 +32,7 @@ from ppocr.losses import build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
import
tools.program
as
program
import
tools.program
as
program
dist
.
get_world_size
()
dist
.
get_world_size
()
...
@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -94,7 +94,7 @@ def main(config, device, logger, vdl_writer):
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
pre_best_model_dict
=
load_model
(
config
,
model
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
...
...
deploy/slim/quantization/export_model.py
浏览文件 @
ae4167dc
...
@@ -28,7 +28,7 @@ from paddle.jit import to_static
...
@@ -28,7 +28,7 @@ from paddle.jit import to_static
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
...
@@ -101,7 +101,7 @@ def main():
...
@@ -101,7 +101,7 @@ def main():
quanter
=
QAT
(
config
=
quant_config
)
quanter
=
QAT
(
config
=
quant_config
)
quanter
.
quantize
(
model
)
quanter
.
quantize
(
model
)
init
_model
(
config
,
model
)
load
_model
(
config
,
model
)
model
.
eval
()
model
.
eval
()
# build metric
# build metric
...
...
deploy/slim/quantization/quant.py
浏览文件 @
ae4167dc
...
@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
...
@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
import
tools.program
as
program
import
tools.program
as
program
from
paddleslim.dygraph.quant
import
QAT
from
paddleslim.dygraph.quant
import
QAT
...
@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -137,7 +137,7 @@ def main(config, device, logger, vdl_writer):
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
pre_best_model_dict
=
load_model
(
config
,
model
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
...
...
deploy/slim/quantization/quant_kl.py
浏览文件 @
ae4167dc
...
@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
...
@@ -37,7 +37,7 @@ from ppocr.losses import build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
import
tools.program
as
program
import
tools.program
as
program
import
paddleslim
import
paddleslim
from
paddleslim.dygraph.quant
import
QAT
from
paddleslim.dygraph.quant
import
QAT
...
...
ppocr/modeling/architectures/distillation_model.py
浏览文件 @
ae4167dc
...
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
...
@@ -21,7 +21,7 @@ from ppocr.modeling.backbones import build_backbone
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.necks
import
build_neck
from
ppocr.modeling.heads
import
build_head
from
ppocr.modeling.heads
import
build_head
from
.base_model
import
BaseModel
from
.base_model
import
BaseModel
from
ppocr.utils.save_load
import
init_model
,
load_pretrained_params
from
ppocr.utils.save_load
import
load_pretrained_params
__all__
=
[
'DistillationModel'
]
__all__
=
[
'DistillationModel'
]
...
...
ppocr/utils/save_load.py
浏览文件 @
ae4167dc
...
@@ -25,7 +25,7 @@ import paddle
...
@@ -25,7 +25,7 @@ import paddle
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
__all__
=
[
'
init_model'
,
'save_model'
,
'load_dygraph_params
'
]
__all__
=
[
'
load_model
'
]
def
_mkdir_if_not_exist
(
path
,
logger
):
def
_mkdir_if_not_exist
(
path
,
logger
):
...
@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
...
@@ -44,7 +44,7 @@ def _mkdir_if_not_exist(path, logger):
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
raise
OSError
(
'Failed to mkdir {}'
.
format
(
path
))
def
init_model
(
config
,
model
,
optimizer
=
None
,
lr_schedul
er
=
None
):
def
load_model
(
config
,
model
,
optimiz
er
=
None
):
"""
"""
load model from checkpoint or pretrained_model
load model from checkpoint or pretrained_model
"""
"""
...
@@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
...
@@ -54,15 +54,14 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
pretrained_model
=
global_config
.
get
(
'pretrained_model'
)
pretrained_model
=
global_config
.
get
(
'pretrained_model'
)
best_model_dict
=
{}
best_model_dict
=
{}
if
checkpoints
:
if
checkpoints
:
assert
os
.
path
.
exists
(
checkpoints
+
".pdparams"
),
\
if
checkpoints
.
endswith
(
'pdparams'
):
"Given dir {}.pdparams not exist."
.
format
(
checkpoints
)
checkpoints
=
checkpoints
.
replace
(
'.pdparams'
,
''
)
assert
os
.
path
.
exists
(
checkpoints
+
".pdopt"
),
\
assert
os
.
path
.
exists
(
checkpoints
+
".pdopt"
),
\
"Given dir {}.pdopt not exist."
.
format
(
checkpoints
)
f
"The
{
checkpoints
}
.pdopt does not exists!"
para_dict
=
paddle
.
load
(
checkpoints
+
'.pdparams'
)
load_pretrained_params
(
model
,
checkpoints
)
opti_dict
=
paddle
.
load
(
checkpoints
+
'.pdopt'
)
optim_dict
=
paddle
.
load
(
checkpoints
+
'.pdopt'
)
model
.
set_state_dict
(
para_dict
)
if
optimizer
is
not
None
:
if
optimizer
is
not
None
:
optimizer
.
set_state_dict
(
opti_dict
)
optimizer
.
set_state_dict
(
opti
m
_dict
)
if
os
.
path
.
exists
(
checkpoints
+
'.states'
):
if
os
.
path
.
exists
(
checkpoints
+
'.states'
):
with
open
(
checkpoints
+
'.states'
,
'rb'
)
as
f
:
with
open
(
checkpoints
+
'.states'
,
'rb'
)
as
f
:
...
@@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
...
@@ -73,70 +72,31 @@ def init_model(config, model, optimizer=None, lr_scheduler=None):
best_model_dict
[
'start_epoch'
]
=
states_dict
[
'epoch'
]
+
1
best_model_dict
[
'start_epoch'
]
=
states_dict
[
'epoch'
]
+
1
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
logger
.
info
(
"resume from {}"
.
format
(
checkpoints
))
elif
pretrained_model
:
elif
pretrained_model
:
if
not
isinstance
(
pretrained_model
,
list
):
load_pretrained_params
(
model
,
pretrained_model
)
pretrained_model
=
[
pretrained_model
]
for
pretrained
in
pretrained_model
:
if
not
(
os
.
path
.
isdir
(
pretrained
)
or
os
.
path
.
exists
(
pretrained
+
'.pdparams'
)):
raise
ValueError
(
"Model pretrain path {} does not "
"exists."
.
format
(
pretrained
))
param_state_dict
=
paddle
.
load
(
pretrained
+
'.pdparams'
)
model
.
set_state_dict
(
param_state_dict
)
logger
.
info
(
"load pretrained model from {}"
.
format
(
pretrained_model
))
else
:
else
:
logger
.
info
(
'train from scratch'
)
logger
.
info
(
'train from scratch'
)
return
best_model_dict
return
best_model_dict
def
load_dygraph_params
(
config
,
model
,
logger
,
optimizer
):
ckp
=
config
[
'Global'
][
'checkpoints'
]
if
ckp
and
os
.
path
.
exists
(
ckp
+
".pdparams"
):
pre_best_model_dict
=
init_model
(
config
,
model
,
optimizer
)
return
pre_best_model_dict
else
:
pm
=
config
[
'Global'
][
'pretrained_model'
]
if
pm
is
None
:
return
{}
if
not
os
.
path
.
exists
(
pm
)
and
not
os
.
path
.
exists
(
pm
+
".pdparams"
):
logger
.
info
(
f
"The pretrained_model
{
pm
}
does not exists!"
)
return
{}
pm
=
pm
if
pm
.
endswith
(
'.pdparams'
)
else
pm
+
'.pdparams'
params
=
paddle
.
load
(
pm
)
state_dict
=
model
.
state_dict
()
new_state_dict
=
{}
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
new_state_dict
[
k1
]
=
params
[
k2
]
else
:
logger
.
info
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
)
model
.
set_state_dict
(
new_state_dict
)
logger
.
info
(
f
"loaded pretrained_model successful from
{
pm
}
"
)
return
{}
def
load_pretrained_params
(
model
,
path
):
def
load_pretrained_params
(
model
,
path
):
if
path
is
None
:
logger
=
get_logger
()
return
False
if
path
.
endswith
(
'pdparams'
):
if
not
os
.
path
.
exists
(
path
)
and
not
os
.
path
.
exists
(
path
+
".pdparams"
):
path
=
path
.
replace
(
'.pdparams'
,
''
)
print
(
f
"The pretrained_model
{
path
}
does not exists!"
)
assert
os
.
path
.
exists
(
path
+
".pdparams"
),
\
return
False
f
"The
{
path
}
.pdparams does not exists!"
path
=
path
if
path
.
endswith
(
'.pdparams'
)
else
path
+
'.pdparams'
params
=
paddle
.
load
(
path
+
'.pdparams'
)
params
=
paddle
.
load
(
path
)
state_dict
=
model
.
state_dict
()
state_dict
=
model
.
state_dict
()
new_state_dict
=
{}
new_state_dict
=
{}
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
for
k1
,
k2
in
zip
(
state_dict
.
keys
(),
params
.
keys
()):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
if
list
(
state_dict
[
k1
].
shape
)
==
list
(
params
[
k2
].
shape
):
new_state_dict
[
k1
]
=
params
[
k2
]
new_state_dict
[
k1
]
=
params
[
k2
]
else
:
else
:
print
(
logger
.
info
(
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
f
"The shape of model params
{
k1
}
{
state_dict
[
k1
].
shape
}
not matched with loaded params
{
k2
}
{
params
[
k2
].
shape
}
!"
)
)
model
.
set_state_dict
(
new_state_dict
)
model
.
set_state_dict
(
new_state_dict
)
print
(
f
"load pretrain successful from
{
path
}
"
)
logger
.
info
(
f
"load pretrain successful from
{
path
}
"
)
return
model
return
model
...
...
tools/eval.py
浏览文件 @
ae4167dc
...
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
...
@@ -27,7 +27,7 @@ from ppocr.data import build_dataloader
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
from
ppocr.utils.save_load
import
load_model
from
ppocr.utils.utility
import
print_dict
from
ppocr.utils.utility
import
print_dict
import
tools.program
as
program
import
tools.program
as
program
...
@@ -60,7 +60,7 @@ def main():
...
@@ -60,7 +60,7 @@ def main():
else
:
else
:
model_type
=
None
model_type
=
None
best_model_dict
=
load_
dygraph_params
(
config
,
model
,
logger
,
None
)
best_model_dict
=
load_
model
(
config
,
model
)
if
len
(
best_model_dict
):
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
for
k
,
v
in
best_model_dict
.
items
():
...
...
tools/export_center.py
浏览文件 @
ae4167dc
...
@@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
...
@@ -27,7 +27,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
from
ppocr.data
import
build_dataloader
from
ppocr.data
import
build_dataloader
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
from
ppocr.utils.save_load
import
load_model
from
ppocr.utils.utility
import
print_dict
from
ppocr.utils.utility
import
print_dict
import
tools.program
as
program
import
tools.program
as
program
...
@@ -57,7 +57,7 @@ def main():
...
@@ -57,7 +57,7 @@ def main():
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
best_model_dict
=
load_
dygraph_params
(
config
,
model
,
logger
,
None
)
best_model_dict
=
load_
model
(
config
,
model
)
if
len
(
best_model_dict
):
if
len
(
best_model_dict
):
logger
.
info
(
'metric in ckpt ***************'
)
logger
.
info
(
'metric in ckpt ***************'
)
for
k
,
v
in
best_model_dict
.
items
():
for
k
,
v
in
best_model_dict
.
items
():
...
...
tools/export_model.py
浏览文件 @
ae4167dc
...
@@ -26,7 +26,7 @@ from paddle.jit import to_static
...
@@ -26,7 +26,7 @@ from paddle.jit import to_static
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
from
ppocr.utils.logging
import
get_logger
from
ppocr.utils.logging
import
get_logger
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
from
tools.program
import
load_config
,
merge_config
,
ArgsParser
...
@@ -107,7 +107,7 @@ def main():
...
@@ -107,7 +107,7 @@ def main():
else
:
# base rec model
else
:
# base rec model
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
config
[
"Architecture"
][
"Head"
][
"out_channels"
]
=
char_num
model
=
build_model
(
config
[
"Architecture"
])
model
=
build_model
(
config
[
"Architecture"
])
init
_model
(
config
,
model
)
load
_model
(
config
,
model
)
model
.
eval
()
model
.
eval
()
save_path
=
config
[
"Global"
][
"save_inference_dir"
]
save_path
=
config
[
"Global"
][
"save_inference_dir"
]
...
...
tools/infer_cls.py
浏览文件 @
ae4167dc
...
@@ -32,7 +32,7 @@ import paddle
...
@@ -32,7 +32,7 @@ import paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
tools.program
as
program
...
@@ -47,7 +47,7 @@ def main():
...
@@ -47,7 +47,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init
_model
(
config
,
model
)
load
_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
...
tools/infer_det.py
浏览文件 @
ae4167dc
...
@@ -34,7 +34,7 @@ import paddle
...
@@ -34,7 +34,7 @@ import paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
from
ppocr.utils.save_load
import
load_model
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
tools.program
as
program
...
@@ -59,7 +59,7 @@ def main():
...
@@ -59,7 +59,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
_
=
load_dygraph_params
(
config
,
model
,
logger
,
None
)
load_model
(
config
,
model
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
post_process_class
=
build_post_process
(
config
[
'PostProcess'
])
...
...
tools/infer_e2e.py
浏览文件 @
ae4167dc
...
@@ -34,7 +34,7 @@ import paddle
...
@@ -34,7 +34,7 @@ import paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
tools.program
as
program
...
@@ -68,7 +68,7 @@ def main():
...
@@ -68,7 +68,7 @@ def main():
# build model
# build model
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init
_model
(
config
,
model
)
load
_model
(
config
,
model
)
# build post process
# build post process
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
post_process_class
=
build_post_process
(
config
[
'PostProcess'
],
...
...
tools/infer_rec.py
浏览文件 @
ae4167dc
...
@@ -33,7 +33,7 @@ import paddle
...
@@ -33,7 +33,7 @@ import paddle
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
tools.program
as
program
...
@@ -58,7 +58,7 @@ def main():
...
@@ -58,7 +58,7 @@ def main():
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init
_model
(
config
,
model
)
load
_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
@@ -75,9 +75,7 @@ def main():
...
@@ -75,9 +75,7 @@ def main():
'gsrm_slf_attn_bias1'
,
'gsrm_slf_attn_bias2'
'gsrm_slf_attn_bias1'
,
'gsrm_slf_attn_bias2'
]
]
elif
config
[
'Architecture'
][
'algorithm'
]
==
"SAR"
:
elif
config
[
'Architecture'
][
'algorithm'
]
==
"SAR"
:
op
[
op_name
][
'keep_keys'
]
=
[
op
[
op_name
][
'keep_keys'
]
=
[
'image'
,
'valid_ratio'
]
'image'
,
'valid_ratio'
]
else
:
else
:
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
op
[
op_name
][
'keep_keys'
]
=
[
'image'
]
transforms
.
append
(
op
)
transforms
.
append
(
op
)
...
...
tools/infer_table.py
浏览文件 @
ae4167dc
...
@@ -34,11 +34,12 @@ from paddle.jit import to_static
...
@@ -34,11 +34,12 @@ from paddle.jit import to_static
from
ppocr.data
import
create_operators
,
transform
from
ppocr.data
import
create_operators
,
transform
from
ppocr.modeling.architectures
import
build_model
from
ppocr.modeling.architectures
import
build_model
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.utils.save_load
import
init
_model
from
ppocr.utils.save_load
import
load
_model
from
ppocr.utils.utility
import
get_image_file_list
from
ppocr.utils.utility
import
get_image_file_list
import
tools.program
as
program
import
tools.program
as
program
import
cv2
import
cv2
def
main
(
config
,
device
,
logger
,
vdl_writer
):
def
main
(
config
,
device
,
logger
,
vdl_writer
):
global_config
=
config
[
'Global'
]
global_config
=
config
[
'Global'
]
...
@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -53,7 +54,7 @@ def main(config, device, logger, vdl_writer):
model
=
build_model
(
config
[
'Architecture'
])
model
=
build_model
(
config
[
'Architecture'
])
init_model
(
config
,
model
,
logger
)
load_model
(
config
,
model
)
# create data ops
# create data ops
transforms
=
[]
transforms
=
[]
...
@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
...
@@ -104,4 +105,3 @@ def main(config, device, logger, vdl_writer):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
config
,
device
,
logger
,
vdl_writer
=
program
.
preprocess
()
main
(
config
,
device
,
logger
,
vdl_writer
)
main
(
config
,
device
,
logger
,
vdl_writer
)
tools/train.py
浏览文件 @
ae4167dc
...
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
...
@@ -35,7 +35,7 @@ from ppocr.losses import build_loss
from
ppocr.optimizer
import
build_optimizer
from
ppocr.optimizer
import
build_optimizer
from
ppocr.postprocess
import
build_post_process
from
ppocr.postprocess
import
build_post_process
from
ppocr.metrics
import
build_metric
from
ppocr.metrics
import
build_metric
from
ppocr.utils.save_load
import
init_model
,
load_dygraph_params
from
ppocr.utils.save_load
import
load_model
import
tools.program
as
program
import
tools.program
as
program
dist
.
get_world_size
()
dist
.
get_world_size
()
...
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
eval_class
=
build_metric
(
config
[
'Metric'
])
# load pretrain model
# load pretrain model
pre_best_model_dict
=
load_
dygraph_params
(
config
,
model
,
logger
,
optimizer
)
pre_best_model_dict
=
load_
model
(
config
,
model
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
logger
.
info
(
'train dataloader has {} iters'
.
format
(
len
(
train_dataloader
)))
if
valid_dataloader
is
not
None
:
if
valid_dataloader
is
not
None
:
logger
.
info
(
'valid dataloader has {} iters'
.
format
(
logger
.
info
(
'valid dataloader has {} iters'
.
format
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录