提交 5d338906 编写于 作者: B breezedeus

support onnx models for predictions

上级 2f8a95fe
......@@ -39,9 +39,11 @@ from cnocr.utils import (
load_model_params,
rescale_img,
pad_img_seq,
to_numpy,
)
from .data_utils.aug import NormalizeAug
from .line_split import line_split
from .models.ctc import CTCPostProcessor
logger = logging.getLogger(__name__)
......@@ -62,7 +64,9 @@ class CnOcr(object):
cand_alphabet: Optional[Union[Collection, str]] = None,
context: str = 'cpu', # ['cpu', 'gpu', 'cuda']
model_fp: Optional[str] = None,
model_backend: str = 'onnx', # ['pytorch', 'onnx']
root: Union[str, Path] = data_dir(),
vocab_fp: Union[str, Path] = VOCAB_FP,
**kwargs,
):
"""
......@@ -73,9 +77,11 @@ class CnOcr(object):
cand_alphabet (Optional[Union[Collection, str]]): 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
context (str): 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为 `cpu`
model_fp (Optional[str]): 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件)
model_backend (str): 'pytorch', or 'onnx'。表明预测时是使用 PyTorch 模型,还是使用 ONNX 模型。默认为 `onnx`
root (Union[str, Path]): 模型文件所在的根目录。
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`。
Windows下默认值为 `C:/Users/<username>/AppData/Roaming/cnocr`。
vocab_fp (Union[str, Path]): 字符集合的文件路径,即 `label_cn.txt` 文件路径
**kwargs: 目前未被使用。
Examples:
......@@ -89,6 +95,8 @@ class CnOcr(object):
>>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
"""
model_backend = model_backend.lower()
assert model_backend in ('pytorch', 'onnx')
if 'name' in kwargs:
logger.warning(
'param `name` is useless and deprecated since version %s'
......@@ -96,22 +104,31 @@ class CnOcr(object):
)
check_model_name(model_name)
check_context(context)
self._model_name = model_name
self._model_backend = model_backend
if context == 'gpu':
context = 'cuda'
self.context = context
self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, model_name)
model_epoch = AVAILABLE_MODELS.get(model_name, [None])[0]
if model_epoch is not None:
self._model_file_prefix = '%s-epoch=%03d' % (
self._model_file_prefix,
model_epoch,
try:
self._assert_and_prepare_model_files(model_fp, root)
except NotImplementedError:
logger.warning(
'no available model is found for name %s and backend %s'
% (self._model_name, self._model_backend)
)
self._model_backend = (
'onnx' if self._model_backend == 'pytorch' else 'pytorch'
)
logger.warning(
'trying to use name %s and backend %s'
% (self._model_name, self._model_backend)
)
self._assert_and_prepare_model_files(model_fp, root)
self._assert_and_prepare_model_files(model_fp, root)
self._vocab, self._letter2id = read_charset(VOCAB_FP)
self._vocab, self._letter2id = read_charset(vocab_fp)
self.postprocessor = CTCPostProcessor(vocab=self._vocab)
self._candidates = None
self.set_cand_alphabet(cand_alphabet)
......@@ -119,6 +136,15 @@ class CnOcr(object):
self._model = self._get_model(context)
def _assert_and_prepare_model_files(self, model_fp, root):
self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, self._model_name)
model_epoch = AVAILABLE_MODELS.get((self._model_name, self._model_backend), [None])[0]
if model_epoch is not None:
self._model_file_prefix = '%s-epoch=%03d' % (
self._model_file_prefix,
model_epoch,
)
if model_fp is not None and not os.path.isfile(model_fp):
raise FileNotFoundError('can not find model file %s' % model_fp)
......@@ -128,25 +154,37 @@ class CnOcr(object):
root = os.path.join(root, MODEL_VERSION)
self._model_dir = os.path.join(root, self._model_name)
fps = glob('%s/%s*.ckpt' % (self._model_dir, self._model_file_prefix))
model_ext = 'ckpt' if self._model_backend == 'pytorch' else 'onnx'
fps = glob('%s/%s*.%s' % (self._model_dir, self._model_file_prefix, model_ext))
if len(fps) > 1:
raise ValueError(
'multiple ckpt files are found in %s, not sure which one should be used'
% self._model_dir
'multiple %s files are found in %s, not sure which one should be used'
% (model_ext, self._model_dir)
)
elif len(fps) < 1:
logger.warning('no ckpt file is found in %s' % self._model_dir)
get_model_file(self._model_dir) # download the .zip file and unzip
fps = glob('%s/%s*.ckpt' % (self._model_dir, self._model_file_prefix))
logger.warning('no %s file is found in %s' % (model_ext, self._model_dir))
get_model_file(
self._model_name, self._model_backend, self._model_dir
) # download the .zip file and unzip
fps = glob(
'%s/%s*.%s' % (self._model_dir, self._model_file_prefix, model_ext)
)
self._model_fp = fps[0]
def _get_model(self, context):
logger.info('use model: %s' % self._model_fp)
model = gen_model(self._model_name, self._vocab)
model.eval()
model.to(self.context)
model = load_model_params(model, self._model_fp, context)
if self._model_backend == 'pytorch':
model = gen_model(self._model_name, self._vocab)
model.eval()
model.to(self.context)
model = load_model_params(model, self._model_fp, context)
elif self._model_backend == 'onnx':
import onnxruntime
model = onnxruntime.InferenceSession(self._model_fp)
else:
raise NotImplementedError(f'{self._model_backend} is not supported yet')
return model
......@@ -335,11 +373,33 @@ class CnOcr(object):
img = rescale_img(img.transpose((2, 0, 1))) # res: [C, H, W]
return NormalizeAug()(img).to(device=torch.device(self.context))
@torch.no_grad()
def _predict(self, img_list: List[torch.Tensor]):
img_lengths = torch.tensor([img.shape[2] for img in img_list])
imgs = pad_img_seq(img_list)
out = self._model(
imgs, img_lengths, candidates=self._candidates, return_preds=True
if self._model_backend == 'pytorch':
with torch.no_grad():
out = self._model(
imgs, img_lengths, candidates=self._candidates, return_preds=True
)
else: # onnx
out = self._onnx_predict(imgs, img_lengths)
return out
def _onnx_predict(self, imgs, img_lengths):
    """Run the ONNX inference session on a batch and decode its output.

    ``imgs`` and ``img_lengths`` are converted with ``to_numpy`` and fed to
    the first and second inputs of the session stored in ``self._model``.
    The first two session outputs are wrapped back into torch tensors as
    the logits and the output lengths.

    Returns:
        dict with the keys ``'logits'`` (after ``OcrModel.mask_by_candidates``),
        ``'output_lengths'`` and ``'preds'`` (the result of
        ``self.postprocessor`` on the two former values).
    """
    session = self._model
    input_specs = session.get_inputs()
    feed = {
        input_specs[0].name: to_numpy(imgs),
        input_specs[1].name: to_numpy(img_lengths),
    }
    raw_outputs = session.run(None, feed)

    logits = torch.from_numpy(raw_outputs[0])
    output_lengths = torch.from_numpy(raw_outputs[1])

    # Apply the current candidate set to the logits before decoding them.
    logits = OcrModel.mask_by_candidates(
        logits, self._candidates, self._vocab, self._letter2id
    )
    return {
        'logits': logits,
        'output_lengths': output_lengths,
        'preds': self.postprocessor(logits, output_lengths),
    }
......@@ -107,12 +107,16 @@ root_url = (
)
# name: (epoch, url)
AVAILABLE_MODELS = {
'densenet_lite_114-fc': (37, root_url + 'densenet_lite_114-fc.zip'),
'densenet_lite_124-fc': (39, root_url + 'densenet_lite_124-fc.zip'),
'densenet_lite_134-fc': (34, root_url + 'densenet_lite_134-fc.zip'),
'densenet_lite_136-fc': (39, root_url + 'densenet_lite_136-fc.zip'),
'densenet_lite_134-gru': (2, root_url + 'densenet_lite_134-gru.zip'),
'densenet_lite_136-gru': (2, root_url + 'densenet_lite_136-gru.zip'),
('densenet_lite_114-fc', 'pytorch'): (37, root_url + 'densenet_lite_114-fc.zip'),
('densenet_lite_124-fc', 'pytorch'): (39, root_url + 'densenet_lite_124-fc.zip'),
('densenet_lite_134-fc', 'pytorch'): (34, root_url + 'densenet_lite_134-fc.zip'),
('densenet_lite_136-fc', 'pytorch'): (39, root_url + 'densenet_lite_136-fc.zip'),
('densenet_lite_114-fc', 'onnx'): (37, root_url + 'densenet_lite_114-fc-onnx.zip'),
('densenet_lite_124-fc', 'onnx'): (39, root_url + 'densenet_lite_124-fc-onnx.zip'),
('densenet_lite_134-fc', 'onnx'): (34, root_url + 'densenet_lite_134-fc-onnx.zip'),
('densenet_lite_136-fc', 'onnx'): (39, root_url + 'densenet_lite_136-fc-onnx.zip'),
('densenet_lite_134-gru', 'pytorch'): (2, root_url + 'densenet_lite_134-gru.zip'),
('densenet_lite_136-gru', 'pytorch'): (2, root_url + 'densenet_lite_136-gru.zip'),
}
# 候选字符集合
......
......@@ -107,6 +107,12 @@ def check_model_name(model_name):
assert decoder_type in DECODER_CONFIGS
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a torch tensor into a NumPy array.

    ``detach()`` is applied unconditionally: it is needed for tensors that
    track gradients and is harmless for tensors that do not, so there is no
    need to branch on ``requires_grad``.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to convert.

    Returns
    -------
    np.ndarray
        Array holding the values of ``tensor`` after moving it to the CPU.
    """
    return tensor.detach().cpu().numpy()
def check_sha1(filename, sha1_hash):
"""Check whether the sha1 hash of the file content matches the expected hash.
Parameters
......@@ -202,7 +208,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None):
return fname
def get_model_file(model_dir):
def get_model_file(model_name, model_backend, model_dir):
r"""Return location for the downloaded models on local file system.
This function will download from online model zoo when model cannot be found or has mismatch.
......@@ -210,6 +216,8 @@ def get_model_file(model_dir):
Parameters
----------
model_name : str
model_backend : str
model_dir : str, default $CNOCR_HOME
Location for keeping the model parameters.
......@@ -222,14 +230,12 @@ def get_model_file(model_dir):
par_dir = os.path.dirname(model_dir)
os.makedirs(par_dir, exist_ok=True)
zip_file_path = model_dir + '.zip'
if (model_name, model_backend) not in AVAILABLE_MODELS:
raise NotImplementedError('%s is not a downloadable model' % model_name)
url = AVAILABLE_MODELS[(model_name, model_backend)][1]
zip_file_path = os.path.join(par_dir, os.path.basename(url))
if not os.path.exists(zip_file_path):
model_name = os.path.basename(model_dir)
if model_name not in AVAILABLE_MODELS:
raise NotImplementedError(
'%s is not an available downloaded model' % model_name
)
url = AVAILABLE_MODELS[model_name][1]
download(url, path=zip_file_path, overwrite=True)
with zipfile.ZipFile(zip_file_path) as zf:
zf.extractall(par_dir)
......
......@@ -19,7 +19,10 @@
import os
import sys
import logging
import time
import pytest
import numpy as np
from PIL import Image
import Levenshtein
......@@ -28,13 +31,15 @@ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr import CnOcr
from cnocr.utils import read_img
from cnocr.utils import set_logger, read_img
from cnocr.consts import NUMBERS, AVAILABLE_MODELS
from cnocr.line_split import line_split
logger = set_logger(log_level=logging.INFO)
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
example_dir = os.path.join(root_dir, 'docs/examples')
CNOCR = CnOcr(model_name='densenet-s-fc', model_epoch=None)
CNOCR = CnOcr(model_name='densenet_lite_136-fc', model_epoch=None)
SINGLE_LINE_CASES = [
('20457890_2399557098.jpg', ['就会哈哈大笑。3.0']),
......@@ -110,8 +115,7 @@ def cal_score(preds, expected):
@pytest.mark.parametrize('img_fp, expected', CASES)
def test_ocr(img_fp, expected):
ocr = CNOCR
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
img_fp = os.path.join(root_dir, 'examples', img_fp)
img_fp = os.path.join(example_dir, img_fp)
pred = ocr.ocr(img_fp)
print('\n')
......@@ -132,8 +136,7 @@ def test_ocr(img_fp, expected):
@pytest.mark.parametrize('img_fp, expected', SINGLE_LINE_CASES)
def test_ocr_for_single_line(img_fp, expected):
ocr = CNOCR
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
img_fp = os.path.join(root_dir, 'examples', img_fp)
img_fp = os.path.join(example_dir, img_fp)
pred = ocr.ocr_for_single_line(img_fp)
print('\n')
print_preds([pred])
......@@ -165,8 +168,7 @@ def test_ocr_for_single_line(img_fp, expected):
@pytest.mark.parametrize('img_fp, expected', MULTIPLE_LINE_CASES)
def test_ocr_for_single_lines(img_fp, expected):
ocr = CNOCR
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
img_fp = os.path.join(root_dir, 'examples', img_fp)
img_fp = os.path.join(example_dir, img_fp)
img = read_img(img_fp)
if img.mean() < 145: # 把黑底白字的图片对调为白底黑字
img = 255 - img
......@@ -186,26 +188,37 @@ def test_ocr_for_single_lines(img_fp, expected):
def test_cand_alphabet():
img_fp = os.path.join(example_dir, 'hybrid.png')
ocr = CnOcr(cand_alphabet=NUMBERS)
pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p, _ in pred]
ocr = CnOcr('densenet_lite_136-fc', cand_alphabet=NUMBERS)
pt_pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p, _ in pt_pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == '012345678'
ocr = CnOcr('densenet_lite_136-fc', model_backend='onnx', cand_alphabet=NUMBERS)
onnx_pred = ocr.ocr(img_fp)
pred = [''.join(line_p) for line_p, _ in onnx_pred]
print("Predicted Chars:", pred)
assert len(pred) == 1 and pred[0] == '012345678'
INSTANCE_ID = 0
assert pt_pred[0][0] == onnx_pred[0][0]
assert abs(pt_pred[0][1] - onnx_pred[0][1]) < 1e-5
@pytest.mark.parametrize('model_name', AVAILABLE_MODELS.keys())
def test_multiple_instances(model_name):
global INSTANCE_ID
print('test multiple instances for model_name: %s' % model_name)
img_fp = os.path.join(example_dir, 'hybrid.png')
INSTANCE_ID += 1
print('instance id: %d' % INSTANCE_ID)
cnocr1 = CnOcr(model_name, name='instance-%d' % INSTANCE_ID)
print_preds(cnocr1.ocr(img_fp))
INSTANCE_ID += 1
print('instance id: %d' % INSTANCE_ID)
cnocr2 = CnOcr(model_name, name='instance-%d' % INSTANCE_ID, cand_alphabet=NUMBERS)
print_preds(cnocr2.ocr(img_fp))
@pytest.mark.parametrize('img_fp, expected', SINGLE_LINE_CASES)
def test_onnx(img_fp, expected):
    """Check that the PyTorch and ONNX backends agree on a single-line image.

    Both backends recognise the same example image with the
    ``densenet_lite_136-fc`` model; the wall-clock time of each call is
    printed. The first elements of the two results must be equal and the
    second elements must differ by less than 1e-5.
    """
    img_fp = os.path.join(example_dir, img_fp)

    # PyTorch backend
    torch_backend_ocr = CnOcr('densenet_lite_136-fc', model_backend='pytorch')
    start_time = time.time()
    pt_preds = torch_backend_ocr.ocr_for_single_line(img_fp)
    end_time = time.time()
    print(f'\npytorch time cost {end_time - start_time}', pt_preds)

    # ONNX backend
    onnx_backend_ocr = CnOcr('densenet_lite_136-fc', model_backend='onnx')
    start_time = time.time()
    onnx_preds = onnx_backend_ocr.ocr_for_single_line(img_fp)
    end_time = time.time()
    print(f'onnx time cost {end_time - start_time}', onnx_preds, '\n\n')

    pt_first, onnx_first = pt_preds[0], onnx_preds[0]
    assert pt_first == onnx_first
    second_diff = abs(pt_preds[1] - onnx_preds[1])
    assert second_diff < 1e-5
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册