提交 09fc5014 编写于 作者: B breezedeus

feat: add command `export-onnx` to export onnx models

上级 ba6fbe76
......@@ -30,10 +30,22 @@ from pathlib import Path
import click
import Levenshtein
from torchvision import transforms as T
import torch
from cnocr.consts import MODEL_VERSION, ENCODER_CONFIGS, DECODER_CONFIGS
from cnocr.utils import set_logger, load_model_params, check_model_name, save_img, read_img
from cnocr.data_utils.aug import NormalizeAug, RandomPaddingAug, RandomStretchAug, RandomCrop
from cnocr.utils import (
set_logger,
load_model_params,
check_model_name,
save_img,
read_img,
)
from cnocr.data_utils.aug import (
NormalizeAug,
RandomPaddingAug,
RandomStretchAug,
RandomCrop,
)
from cnocr.dataset import OcrDataModule
from cnocr.trainer import PlTrainer, resave_model
from cnocr import CnOcr, gen_model
......@@ -60,7 +72,7 @@ def cli():
'--model-name',
type=str,
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
help='模型名称。默认值为 `%s`' % DEFAULT_MODEL_NAME,
)
@click.option(
'-i',
......@@ -80,19 +92,20 @@ def cli():
'--resume-from-checkpoint',
type=str,
default=None,
help='恢复此前中断的训练状态,继续训练。默认为 `None`',
help='恢复此前中断的训练状态,继续训练。所以文件中应该包含训练状态。默认为 `None`',
)
@click.option(
'-p',
'--pretrained-model-fp',
type=str,
default=None,
help='导入的训练好的模型,作为初始模型。'
'优先级低于"--restore-training-fp",当传入"--restore-training-fp"时,此传入失效。默认为 `None`',
help='导入的训练好的模型,作为模型初始值。'
'优先级低于"--resume-from-checkpoint",当传入"--resume-from-checkpoint"时,此传入失效。默认为 `None`',
)
def train(
model_name, index_dir, train_config_fp, resume_from_checkpoint, pretrained_model_fp
):
"""训练模型"""
check_model_name(model_name)
train_transform = T.Compose(
[
......@@ -187,6 +200,7 @@ def visualize_example(example, fp_prefix):
help="是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后再进行识别",
)
def predict(model_name, pretrained_model_fp, context, img_file_or_dir, single_line):
"""模型预测"""
ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context)
ocr_func = ocr.ocr_for_single_line if single_line else ocr.ocr
fp_list = []
......@@ -260,6 +274,7 @@ def evaluate(
output_dir,
verbose,
):
"""评估模型效果"""
ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context)
fn_labels_list = read_input_file(eval_index_fp)
......@@ -371,5 +386,67 @@ def resave_model_file(
resave_model(input_model_fp, output_model_fp, map_location='cpu')
def export_to_onnx(model_name, output_model_fp, input_model_fp=None):
    """Export a CnOcr model to an ONNX file and validate the result.

    The model is obtained from a `CnOcr` instance, its postprocessor is
    removed, and the network is exported with `torch.onnx.export` using dummy
    inputs. The written file is then reloaded and checked with
    `onnx.checker.check_model`.

    Args:
        model_name: model name, forwarded to `CnOcr`.
        output_model_fp: path of the ONNX file to write.
        input_model_fp: optional model file, forwarded to `CnOcr` as
            `model_fp`.
    """
    # Local import; presumably so `onnx` is only required when exporting.
    import onnx

    net = CnOcr(model_name, model_fp=input_model_fp)._model

    # Dummy inputs for the export. Axes 0 and 3 of the image tensor are
    # declared dynamic below ("batch_size" and "width").
    # NOTE(review): axes 1 and 2 look like channels and height -- confirm.
    dummy_image = torch.randn(1, 1, 32, 280)
    dummy_lengths = torch.tensor([280])

    net.postprocessor = None  # the postprocessor cannot be converted to ONNX

    dynamic_axes = {
        'x': {0: 'batch_size', 3: 'width'},  # variable-length axes
        'input_lengths': {0: 'batch_size'},
        'logits': {0: 'batch_size'},
    }

    with torch.no_grad():
        net.eval()
        torch.onnx.export(
            net,
            args=(dummy_image, dummy_lengths),
            f=output_model_fp,
            export_params=True,
            do_constant_folding=True,
            input_names=['x', 'input_lengths'],
            output_names=['logits', 'output_lengths'],
            dynamic_axes=dynamic_axes,
        )

    # Reload the written file and run the ONNX structural checker on it.
    exported = onnx.load(output_model_fp)
    onnx.checker.check_model(exported)
    logger.info('model is exported to %s' % output_model_fp)
@cli.command('export-onnx')
@click.option(
    '-m',
    '--model-name',
    type=str,
    default=DEFAULT_MODEL_NAME,
    help='模型名称。默认值为 `%s`' % DEFAULT_MODEL_NAME,
)
@click.option(
    '-i',
    '--input-model-fp',
    type=str,
    default=None,
    help='输入的模型文件路径。 默认为 `None`,表示使用系统自带的预训练模型',
)
@click.option(
    '-o',
    '--output-model-fp',
    type=str,
    required=True,
    help='输出的模型文件路径(.onnx)',
)
def export_onnx_model(model_name, input_model_fp, output_model_fp):
    """把训练好的模型导出为 ONNX 格式。
    当前无法导出 `*-gru` 模型, 具体说明见:https://discuss.pytorch.org/t/exporting-gru-rnn-to-onnx/27244 ,
    后续版本会修复此问题。
    """
    # The docstring above is the text click shows for `export-onnx --help`,
    # so it is kept verbatim. In English: "Export a trained model to ONNX
    # format. `*-gru` models cannot be exported at present (see the linked
    # PyTorch forum thread); a later version will fix this."
    #
    # CLI entry point for the `export-onnx` command: forwards the parsed
    # options to `export_to_onnx`.
    export_to_onnx(model_name, output_model_fp, input_model_fp=input_model_fp)
# Script entry point: run the click command group when executed directly.
if __name__ == '__main__':
    cli()
def resave_model(module_fp, output_model_fp, map_location=None):
    """Re-save a PlTrainer checkpoint as a plain model file.

    Files stored by PlTrainer correspond to its `pl_module`; this function
    converts such a file into the model file that corresponds to `model`.
    When every key of the checkpoint's `state_dict` starts with `model.`,
    that prefix is stripped. Otherwise the weights are written unchanged
    instead of being silently dropped (previously an empty `state_dict` was
    saved in that case).

    Args:
        module_fp: path of the checkpoint to read; it must contain a
            `state_dict` entry.
        output_model_fp: path of the model file to write.
        map_location: forwarded to `torch.load`.

    Raises:
        KeyError: if the loaded checkpoint has no `state_dict` entry.
    """
    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(module_fp, map_location=map_location)
    source_state = checkpoint['state_dict']

    prefix = 'model.'
    if all(key.startswith(prefix) for key in source_state):
        # Drop the leading `model.` so the keys match the bare model.
        state_dict = {key[len(prefix):]: value for key, value in source_state.items()}
    else:
        # Keys are not (all) prefixed: keep the weights as they are.
        state_dict = dict(source_state)

    torch.save({'state_dict': state_dict}, output_model_fp)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册