Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Just_Paranoid
CnOCR
提交
09fc5014
CnOCR
项目概览
Just_Paranoid
/
CnOCR
与 Fork 源项目一致
Fork自
Cloud IDE / CnOCR
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
CnOCR
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
09fc5014
编写于
5月 14, 2022
作者:
B
breezedeus
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat: add command `export-onnx` to export onnx models
上级
ba6fbe76
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
86 addition
and
8 deletion
+86
-8
cnocr/cli.py
cnocr/cli.py
+83
-6
cnocr/trainer.py
cnocr/trainer.py
+3
-2
未找到文件。
cnocr/cli.py
浏览文件 @
09fc5014
...
...
@@ -30,10 +30,22 @@ from pathlib import Path
import
click
import
Levenshtein
from
torchvision
import
transforms
as
T
import
torch
from
cnocr.consts
import
MODEL_VERSION
,
ENCODER_CONFIGS
,
DECODER_CONFIGS
from
cnocr.utils
import
set_logger
,
load_model_params
,
check_model_name
,
save_img
,
read_img
from
cnocr.data_utils.aug
import
NormalizeAug
,
RandomPaddingAug
,
RandomStretchAug
,
RandomCrop
from
cnocr.utils
import
(
set_logger
,
load_model_params
,
check_model_name
,
save_img
,
read_img
,
)
from
cnocr.data_utils.aug
import
(
NormalizeAug
,
RandomPaddingAug
,
RandomStretchAug
,
RandomCrop
,
)
from
cnocr.dataset
import
OcrDataModule
from
cnocr.trainer
import
PlTrainer
,
resave_model
from
cnocr
import
CnOcr
,
gen_model
...
...
@@ -60,7 +72,7 @@ def cli():
'--model-name'
,
type
=
str
,
default
=
DEFAULT_MODEL_NAME
,
help
=
'模型名称。默认值为
%s
'
%
DEFAULT_MODEL_NAME
,
help
=
'模型名称。默认值为
`%s`
'
%
DEFAULT_MODEL_NAME
,
)
@
click
.
option
(
'-i'
,
...
...
@@ -80,19 +92,20 @@ def cli():
'--resume-from-checkpoint'
,
type
=
str
,
default
=
None
,
help
=
'恢复此前中断的训练状态,继续训练。默认为 `None`'
,
help
=
'恢复此前中断的训练状态,继续训练。
所以文件中应该包含训练状态。
默认为 `None`'
,
)
@
click
.
option
(
'-p'
,
'--pretrained-model-fp'
,
type
=
str
,
default
=
None
,
help
=
'导入的训练好的模型,作为
初始模型
。'
'优先级低于"--res
tore-training-fp",当传入"--restore-training-fp
"时,此传入失效。默认为 `None`'
,
help
=
'导入的训练好的模型,作为
模型初始值
。'
'优先级低于"--res
ume-from-checkpoint",当传入"--resume-from-checkpoint
"时,此传入失效。默认为 `None`'
,
)
def
train
(
model_name
,
index_dir
,
train_config_fp
,
resume_from_checkpoint
,
pretrained_model_fp
):
"""训练模型"""
check_model_name
(
model_name
)
train_transform
=
T
.
Compose
(
[
...
...
@@ -187,6 +200,7 @@ def visualize_example(example, fp_prefix):
help
=
"是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后再进行识别"
,
)
def
predict
(
model_name
,
pretrained_model_fp
,
context
,
img_file_or_dir
,
single_line
):
"""模型预测"""
ocr
=
CnOcr
(
model_name
=
model_name
,
model_fp
=
pretrained_model_fp
,
context
=
context
)
ocr_func
=
ocr
.
ocr_for_single_line
if
single_line
else
ocr
.
ocr
fp_list
=
[]
...
...
@@ -260,6 +274,7 @@ def evaluate(
output_dir
,
verbose
,
):
"""评估模型效果"""
ocr
=
CnOcr
(
model_name
=
model_name
,
model_fp
=
pretrained_model_fp
,
context
=
context
)
fn_labels_list
=
read_input_file
(
eval_index_fp
)
...
...
@@ -371,5 +386,67 @@ def resave_model_file(
resave_model
(
input_model_fp
,
output_model_fp
,
map_location
=
'cpu'
)
def
export_to_onnx
(
model_name
,
output_model_fp
,
input_model_fp
=
None
):
import
onnx
ocr
=
CnOcr
(
model_name
,
model_fp
=
input_model_fp
)
model
=
ocr
.
_model
x
=
torch
.
randn
(
1
,
1
,
32
,
280
)
input_lengths
=
torch
.
tensor
([
280
])
model
.
postprocessor
=
None
# 这个无法ONNX化
symbolic_names
=
{
0
:
'batch_size'
,
3
:
'width'
}
with
torch
.
no_grad
():
model
.
eval
()
torch
.
onnx
.
export
(
model
,
args
=
(
x
,
input_lengths
),
f
=
output_model_fp
,
export_params
=
True
,
# opset_version=10,
do_constant_folding
=
True
,
input_names
=
[
'x'
,
'input_lengths'
],
output_names
=
[
'logits'
,
'output_lengths'
],
dynamic_axes
=
{
'x'
:
symbolic_names
,
# variable length axes
'input_lengths'
:
{
0
:
'batch_size'
},
'logits'
:
{
0
:
'batch_size'
},
},
)
onnx_model
=
onnx
.
load
(
output_model_fp
)
onnx
.
checker
.
check_model
(
onnx_model
)
logger
.
info
(
'model is exported to %s'
%
output_model_fp
)
@
cli
.
command
(
'export-onnx'
)
@
click
.
option
(
'-m'
,
'--model-name'
,
type
=
str
,
default
=
DEFAULT_MODEL_NAME
,
help
=
'模型名称。默认值为 `%s`'
%
DEFAULT_MODEL_NAME
,
)
@
click
.
option
(
'-i'
,
'--input-model-fp'
,
type
=
str
,
default
=
None
,
help
=
'输入的模型文件路径。 默认为 `None`,表示使用系统自带的预训练模型'
,
)
@
click
.
option
(
'-o'
,
'--output-model-fp'
,
type
=
str
,
required
=
True
,
help
=
'输出的模型文件路径(.onnx)'
)
def
export_onnx_model
(
model_name
,
input_model_fp
,
output_model_fp
,
):
"""把训练好的模型导出为 ONNX 格式。
当前无法导出 `*-gru` 模型, 具体说明见:https://discuss.pytorch.org/t/exporting-gru-rnn-to-onnx/27244 ,
后续版本会修复此问题。
"""
export_to_onnx
(
model_name
,
output_model_fp
,
input_model_fp
)
if
__name__
==
"__main__"
:
cli
()
cnocr/trainer.py
浏览文件 @
09fc5014
...
...
@@ -242,6 +242,7 @@ def resave_model(module_fp, output_model_fp, map_location=None):
"""PlTrainer存储的文件对应其 `pl_module` 模块,需利用此函数转存为 `model` 对应的模型文件。"""
checkpoint
=
torch
.
load
(
module_fp
,
map_location
=
map_location
)
state_dict
=
{}
for
k
,
v
in
checkpoint
[
'state_dict'
].
items
():
state_dict
[
k
.
split
(
'.'
,
maxsplit
=
1
)[
1
]]
=
v
if
all
([
k
.
startswith
(
'model.'
)
for
k
in
checkpoint
[
'state_dict'
].
keys
()]):
for
k
,
v
in
checkpoint
[
'state_dict'
].
items
():
state_dict
[
k
.
split
(
'.'
,
maxsplit
=
1
)[
1
]]
=
v
torch
.
save
({
'state_dict'
:
state_dict
},
output_model_fp
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录