提交 c91e23e7 编写于 作者: B breezedeus

refactor

上级 63672693
......@@ -4,7 +4,7 @@ English [README](./README_en.md) (`out-dated`).
**cnocr** 是 **Python 3** 下的**文字识别**(**Optical Character Recognition**,简称**OCR**)工具包,支持**中文**和**英文**的常见字符识别,自带了多个训练好的识别模型,安装后即可直接使用。欢迎扫码加入微信交流群:
![QQ群二维码](https://huggingface.co/datasets/breezedeus/cnocr-wx-qr-code/resolve/main/wx-qr-code.JPG)
![微信群二维码](https://huggingface.co/datasets/breezedeus/cnocr-wx-qr-code/resolve/main/wx-qr-code.JPG)
作者也维护 **知识星球** [**CnOCR/CnSTD私享群**](https://t.zsxq.com/FEYZRJQ) ,欢迎加入。**知识星球私享群**会陆续发布一些CnOCR/CnSTD相关的私有资料,包括**更详细的训练教程**,**未公开的模型**,使用过程中遇到的难题解答等。本群也会发布OCR/STD相关的最新研究资料。此外,**私享群中作者每月提供两次免费特有数据的训练服务**。
......
......@@ -37,7 +37,7 @@ from cnocr.utils import (
check_context,
read_img,
load_model_params,
rescale_img,
resize_img,
pad_img_seq,
to_numpy,
)
......@@ -140,7 +140,7 @@ class CnOcr(object):
def _assert_and_prepare_model_files(self, model_fp, root):
self._model_file_prefix = '{}-{}'.format(self.MODEL_FILE_PREFIX, self._model_name)
model_epoch = AVAILABLE_MODELS.get((self._model_name, self._model_backend), [None])[0]
model_epoch = AVAILABLE_MODELS.get_epoch(self._model_name, self._model_backend)
if model_epoch is not None:
self._model_file_prefix = '%s-epoch=%03d' % (
......@@ -373,7 +373,7 @@ class CnOcr(object):
Returns:
torch.Tensor: with shape (1, height, width)
"""
img = rescale_img(img.transpose((2, 0, 1))) # res: [C, H, W]
img = resize_img(img.transpose((2, 0, 1))) # res: [C, H, W]
return NormalizeAug()(img).to(device=torch.device(self.context))
def _predict(self, img_list: List[torch.Tensor]):
......
......@@ -19,8 +19,14 @@
import string
from pathlib import Path
from typing import Tuple, Set, Dict, Any, Optional, Union
import logging
from copy import deepcopy
from .__version__ import __version__
logger = logging.getLogger(__name__)
# 模型版本只对应到第二层,第三层的改动表示模型兼容。
# 如: __version__ = '2.0.*',对应的 MODEL_VERSION 都是 '2.0'
......@@ -101,23 +107,87 @@ DECODER_CONFIGS = {
'fcfull': {'hidden_size': 256, 'dropout': 0.3,},
}
root_url = (
'https://huggingface.co/breezedeus/cnstd-cnocr-models/resolve/main/models/cnocr/%s/'
% MODEL_VERSION
)
# name: (epoch, url)
AVAILABLE_MODELS = {
('densenet_lite_114-fc', 'pytorch'): (37, root_url + 'densenet_lite_114-fc.zip'),
('densenet_lite_124-fc', 'pytorch'): (39, root_url + 'densenet_lite_124-fc.zip'),
('densenet_lite_134-fc', 'pytorch'): (34, root_url + 'densenet_lite_134-fc.zip'),
('densenet_lite_136-fc', 'pytorch'): (39, root_url + 'densenet_lite_136-fc.zip'),
('densenet_lite_114-fc', 'onnx'): (37, root_url + 'densenet_lite_114-fc-onnx.zip'),
('densenet_lite_124-fc', 'onnx'): (39, root_url + 'densenet_lite_124-fc-onnx.zip'),
('densenet_lite_134-fc', 'onnx'): (34, root_url + 'densenet_lite_134-fc-onnx.zip'),
('densenet_lite_136-fc', 'onnx'): (39, root_url + 'densenet_lite_136-fc-onnx.zip'),
('densenet_lite_134-gru', 'pytorch'): (2, root_url + 'densenet_lite_134-gru.zip'),
('densenet_lite_136-gru', 'pytorch'): (2, root_url + 'densenet_lite_136-gru.zip'),
}
class AvailableModels(object):
    """
    Registry of the models that can be looked up by ``(model_name, model_backend)``.

    Two tables are kept:

    * ``CNOCR_MODELS``: the built-in models; each value is ``(epoch, zip file name)``.
    * ``OUTER_MODELS``: models added through `register_models`; each value is a
      dict that gets a ``'space'`` entry on registration and from which the
      ``'url'`` and ``'vocab_fp'`` entries are read by the getters below.

    Both tables are class attributes, so registrations are shared by every instance.
    """

    ROOT_URL = (
        'https://huggingface.co/breezedeus/cnstd-cnocr-models/resolve/main/models/cnocr/%s/'
        % MODEL_VERSION
    )
    CNOCR_SPACE = '__cnocr__'

    # (name, backend): (epoch, zip file name relative to ROOT_URL)
    CNOCR_MODELS = {
        ('densenet_lite_114-fc', 'pytorch'): (37, 'densenet_lite_114-fc.zip'),
        ('densenet_lite_124-fc', 'pytorch'): (39, 'densenet_lite_124-fc.zip'),
        ('densenet_lite_134-fc', 'pytorch'): (34, 'densenet_lite_134-fc.zip'),
        ('densenet_lite_136-fc', 'pytorch'): (39, 'densenet_lite_136-fc.zip'),
        ('densenet_lite_114-fc', 'onnx'): (37, 'densenet_lite_114-fc-onnx.zip'),
        ('densenet_lite_124-fc', 'onnx'): (39, 'densenet_lite_124-fc-onnx.zip'),
        ('densenet_lite_134-fc', 'onnx'): (34, 'densenet_lite_134-fc-onnx.zip'),
        ('densenet_lite_136-fc', 'onnx'): (39, 'densenet_lite_136-fc-onnx.zip'),
        ('densenet_lite_134-gru', 'pytorch'): (2, 'densenet_lite_134-gru.zip'),
        ('densenet_lite_136-gru', 'pytorch'): (2, 'densenet_lite_136-gru.zip'),
    }
    # (name, backend): dict describing an externally registered model
    OUTER_MODELS = {}

    def all_models(self) -> Set[Tuple[str, str]]:
        """Return the keys of all built-in and registered models."""
        return set(self.CNOCR_MODELS.keys()) | set(self.OUTER_MODELS.keys())

    def __contains__(self, model_name_backend: Tuple[str, str]) -> bool:
        """Support ``(model_name, model_backend) in AVAILABLE_MODELS``."""
        # Look the key up directly instead of building a throwaway set per test.
        return (
            model_name_backend in self.CNOCR_MODELS
            or model_name_backend in self.OUTER_MODELS
        )

    def register_models(self, model_dict: Dict[Tuple[str, str], Any], space: str):
        """
        Add external models to ``OUTER_MODELS``.

        :param model_dict: maps ``(model_name, model_backend)`` to a dict of model info;
            each dict is deep-copied and tagged with ``'space'`` before being stored
        :param space: name of the space the models belong to; names starting with
            ``'__'`` are rejected (the built-in space is ``CNOCR_SPACE``)
        :return: None; already-known keys are skipped with a warning
        """
        assert not space.startswith('__'), "space names starting with '__' are reserved"
        for key, val in model_dict.items():
            if key in self.CNOCR_MODELS or key in self.OUTER_MODELS:
                # `key` is a tuple: it must be passed as a single logging argument,
                # because `'...%s...' % key` would unpack it and raise TypeError.
                logger.warning(
                    'model %s has already existed, and will be ignored', key
                )
                continue
            val = deepcopy(val)  # do not mutate the caller's dict
            val['space'] = space
            self.OUTER_MODELS[key] = val

    def get_space(self, model_name, model_backend) -> Optional[str]:
        """Return the space of the model, or None when the model is unknown."""
        key = (model_name, model_backend)
        if key in self.CNOCR_MODELS:
            return self.CNOCR_SPACE
        elif key in self.OUTER_MODELS:
            return self.OUTER_MODELS[key]['space']
        return None

    def get_vocab_fp(
        self, model_name: str, model_backend: str
    ) -> Optional[Union[str, Path]]:
        """
        Return the vocabulary file of the model: ``VOCAB_FP`` for built-in models,
        the registered ``'vocab_fp'`` entry for external ones, or None (with a
        warning) when the model is unknown.
        """
        key = (model_name, model_backend)
        if key in self.CNOCR_MODELS:
            return VOCAB_FP
        elif key in self.OUTER_MODELS:
            return self.OUTER_MODELS[key]['vocab_fp']
        else:
            logger.warning('no vocab file is found for model %s', key)
            return None

    def get_epoch(self, model_name, model_backend) -> Optional[int]:
        """Return the epoch of a built-in model; None for any other model."""
        key = (model_name, model_backend)
        if key in self.CNOCR_MODELS:
            return self.CNOCR_MODELS[key][0]
        return None

    def get_url(self, model_name, model_backend) -> Optional[str]:
        """
        Return the download url of the model, or None (with a warning) when the
        model is unknown. The stored value is appended to ``ROOT_URL`` for both
        built-in and registered models.
        """
        key = (model_name, model_backend)
        if key in self.CNOCR_MODELS:
            url = self.CNOCR_MODELS[key][1]
        elif key in self.OUTER_MODELS:
            url = self.OUTER_MODELS[key]['url']
        else:
            logger.warning('no url is found for model %s', key)
            return None
        return self.ROOT_URL + url
AVAILABLE_MODELS = AvailableModels()
# 候选字符集合
NUMBERS = string.digits + string.punctuation
......
......@@ -24,7 +24,7 @@ import pytorch_lightning as pt
import torch
from torch.utils.data import DataLoader, Dataset
from .utils import read_charset, read_tsv_file, read_img, rescale_img, pad_img_seq
from .utils import read_charset, read_tsv_file, read_img, resize_img, pad_img_seq
class OcrDataset(Dataset):
......@@ -41,7 +41,7 @@ class OcrDataset(Dataset):
def __getitem__(self, item):
img_fp = self.img_fp_list[item]
img = read_img(img_fp).transpose((2, 0, 1)) # res: [1, H, W]
img = rescale_img(img)
img = resize_img(img)
if self.mode != 'test':
labels = self.labels_list[item]
......
......@@ -24,7 +24,7 @@ import logging
import platform
import zipfile
import requests
from typing import Union, Any, Tuple, List
from typing import Union, Any, Tuple, List, Optional, Dict
from tqdm import tqdm
from PIL import Image
......@@ -231,8 +231,10 @@ def get_model_file(model_name, model_backend, model_dir):
os.makedirs(par_dir, exist_ok=True)
if (model_name, model_backend) not in AVAILABLE_MODELS:
raise NotImplementedError('%s is not a downloadable model' % model_name)
url = AVAILABLE_MODELS[(model_name, model_backend)][1]
raise NotImplementedError(
'%s is not a downloadable model' % ((model_name, model_backend),)
)
url = AVAILABLE_MODELS.get_url(model_name, model_backend)
zip_file_path = os.path.join(par_dir, os.path.basename(url))
if not os.path.exists(zip_file_path):
......@@ -306,17 +308,29 @@ def save_img(img: Union[Tensor, np.ndarray], path):
# Image.fromarray(img).save(path)
def resize_img(
    img: np.ndarray,
    target_h_w: Optional[Tuple[int, int]] = None,
    return_torch: bool = True,
) -> Union[torch.Tensor, np.ndarray]:
    """
    Resize an image array laid out as [Channel, Height, Width].

    :param img: np.ndarray; should be [c, height, width]
    :param target_h_w: (height, width) of the target image; when None, the image is
        scaled to height `IMG_STANDARD_HEIGHT` and the width follows the same ratio
    :param return_torch: bool; whether to return a `torch.Tensor` or `np.ndarray`
    :return: the resized image with dim [C, height, width]; when the image already
        has the target size it is returned without resizing
    """
    ori_height, ori_width = img.shape[1:]
    if target_h_w is None:
        ratio = ori_height / IMG_STANDARD_HEIGHT
        # Keep at least one column: a very narrow image would otherwise get width 0.
        target_h_w = (IMG_STANDARD_HEIGHT, max(1, int(ori_width / ratio)))
    else:
        # Normalize so a list such as [h, w] compares equal to the (h, w) tuple below.
        target_h_w = tuple(target_h_w)

    if (ori_height, ori_width) != target_h_w:
        img = F.resize(torch.from_numpy(img), target_h_w)
        if not return_torch:
            img = img.numpy()
    elif return_torch:
        img = torch.from_numpy(img)
    return img
......@@ -374,3 +388,29 @@ def get_model_size(model, only_trainable=False):
if only_trainable:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
return sum(p.numel() for p in model.parameters())
def mask_by_candidates(
    logits: np.ndarray,
    candidates: Optional[Union[str, List[str]]],
    vocab: List[str],
    letter2id: Dict[str, int],
    ignored_tokens: List[int],
):
    """
    Restrict `logits` to a set of candidate tokens.

    Every position along the last axis whose id is neither a candidate nor listed
    in `ignored_tokens` is set to -100.0; the remaining values are kept as they are.

    :param logits: np.ndarray whose last axis has size `len(vocab)`
    :param candidates: the allowed tokens (each one is looked up in `letter2id`);
        None disables masking and `logits` is returned unchanged
    :param vocab: the vocabulary; only its length is used
    :param letter2id: maps a token to its index on the last axis of `logits`
    :param ignored_tokens: ids that are always kept, whatever `candidates` contains
    :return: a new array with the same shape and dtype as `logits`
    :raises KeyError: if a candidate is missing from `letter2id`
    """
    if candidates is None:
        return logits

    allowed = np.zeros((len(vocab),), dtype=bool)
    allowed[np.array([letter2id[word] for word in candidates], dtype=int)] = True
    # Tokens such as the blank/padding id are not forced on here; the caller
    # passes them through `ignored_tokens`.
    allowed[ignored_tokens] = True

    # `allowed` broadcasts over all leading axes, so any batch size works
    # (a masked_array with a 1 x T x V mask fails as soon as batch > 1).
    fill_value = np.asarray(-100.0, dtype=logits.dtype)
    return np.where(allowed, logits, fill_value)
......@@ -9,7 +9,6 @@
- 增加了 `cnocr export-onnx` 命令,把训练好的PyTorch模型导出为ONNX模型;
- 去掉了对包 `python-Levenshtein` 的依赖。
### Update 2021.11.06: 发布 cnocr V2.1.0
主要变更:
......
......@@ -121,8 +121,6 @@ Options:
-h, --help Show this message and exit.
```
## PyTorch 模型导出为 ONNX 模型
把训练好的模型导出为 ONNX 格式。 当前无法导出 `*-gru` 模型, 具体说明见:[Exporting GRU RNN to ONNX - PyTorch Forums](https://discuss.pytorch.org/t/exporting-gru-rnn-to-onnx/27244) 。后续版本会修复此问题。
......@@ -141,5 +139,3 @@ Options:
-o, --output-model-fp TEXT 输出的模型文件路径(.onnx) [required]
-h, --help Show this message and exit.
```
......@@ -10,7 +10,6 @@ CnOcr的目标是**使用简单**。
可以使用 [在线 Demo](demo.md) 查看效果。
## 安装简单
嗯,安装真的很简单。
......@@ -38,6 +37,7 @@ print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
......@@ -51,7 +51,6 @@ print("Predicted Chars:", res)
更多说明可见 [使用方法](usage.md)
## 命令行工具
具体见 [命令行工具](command.md)
......@@ -62,23 +61,22 @@ print("Predicted Chars:", res)
## 效果示例
| 图片 | OCR结果 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| ![examples/helloworld.jpg](./examples/helloworld.jpg) | Hello world!你好世界 |
| ![examples/chn-00199989.jpg](./examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 |
| ![examples/chn-00199980.jpg](./examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 |
| ![examples/chn-00199984.jpg](./examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 |
| ![examples/chn-00199985.jpg](./examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 |
| ![examples/chn-00199981.jpg](./examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 |
| ![examples/00199975.jpg](./examples/00199975.jpg) | nd-chips fructed ast |
| ![examples/00199978.jpg](./examples/00199978.jpg) | zouna unpayably Raqu |
| ![examples/00199979.jpg](./examples/00199979.jpg) | ape fissioning Senat |
| ![examples/00199971.jpg](./examples/00199971.jpg) | ling oughtlins near |
| ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为<br />每一个手机号码和邮件地址背后<br />都会对应着一个账户--这个账<br />户可以是信用卡账户、借记卡账<br />户,也包括邮局汇款、手机代<br />收、电话代收、预付费卡和点卡<br />等多种形式。 |
| ![examples/multi-line_cn2.png](./examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,<br />意味着传播方式的变化。过去主流<br />的是大众传播,现在互动性和定制<br />性带来了新的挑战——如何让品牌<br />与消费者更加互动。 |
| 图片 | OCR结果 |
| ----------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ![examples/helloworld.jpg](./examples/helloworld.jpg) | Hello world!你好世界 |
| ![examples/chn-00199989.jpg](./examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 |
| ![examples/chn-00199980.jpg](./examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 |
| ![examples/chn-00199984.jpg](./examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 |
| ![examples/chn-00199985.jpg](./examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 |
| ![examples/chn-00199981.jpg](./examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 |
| ![examples/00199975.jpg](./examples/00199975.jpg) | nd-chips fructed ast |
| ![examples/00199978.jpg](./examples/00199978.jpg) | zouna unpayably Raqu |
| ![examples/00199979.jpg](./examples/00199979.jpg) | ape fissioning Senat |
| ![examples/00199971.jpg](./examples/00199971.jpg) | ling oughtlins near |
| ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为<br />每一个手机号码和邮件地址背后<br />都会对应着一个账户--这个账<br />户可以是信用卡账户、借记卡账<br />户,也包括邮局汇款、手机代<br />收、电话代收、预付费卡和点卡<br />等多种形式。 |
| ![examples/multi-line_cn2.png](./examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,<br />意味着传播方式的变化。过去主流<br />的是大众传播,现在互动性和定制<br />性带来了新的挑战——如何让品牌<br />与消费者更加互动。 |
| ![examples/multi-line_en_white.png](./examples/multi-line_en_white.png) | This chapter is currently only available <br />in this web version. ebook and print will follow.<br />Convolutional neural networks learn abstract <br />features and concepts from raw image pixels. Feature<br />Visualization visualizes the learned features <br />by activation maximization. Network Dissection labels<br />neural network units (e.g. channels) with human concepts. |
| ![examples/multi-line_en_black.png](./examples/multi-line_en_black.png) | transforms the image many times. First, the image <br />goes through many convolutional layers. In those<br />convolutional layers, the network learns new <br />and increasingly complex features in its layers. Then the <br />transformed image information goes through <br />the fully connected layers and turns into a classification<br />or prediction. |
| ![examples/multi-line_en_black.png](./examples/multi-line_en_black.png) | transforms the image many times. First, the image <br />goes through many convolutional layers. In those<br />convolutional layers, the network learns new <br />and increasingly complex features in its layers. Then the <br />transformed image information goes through <br />the fully connected layers and turns into a classification<br />or prediction. |
## 其他文档
......@@ -86,7 +84,6 @@ print("Predicted Chars:", res)
* 对于通用场景的文字识别,使用 [文本检测CnStd + 文字识别CnOcr](cnstd_cnocr.md)
* [RELEASE文档](RELEASE.md)
## 未来工作
* [x] 支持图片包含多行文字 (`Done`)
......@@ -99,4 +96,3 @@ print("Predicted Chars:", res)
* [x] 由 MXNet 改为 PyTorch 架构(since `V2.0.0`)
* [x] 基于 PyTorch 训练更高效的模型
* [ ] 支持列格式的文字识别
......@@ -14,8 +14,6 @@ pip install cnocr -i https://pypi.doubanio.com/simple
> 注意:请使用 **Python3**(3.6以及之后版本应该都行),没测过Python2下是否ok。
### GPU 环境使用 ONNX 模型
默认情况下安装的 **ONNX** 包是 **`onnxruntime`**,它只能在 `CPU` 上运行。如果需要在 `GPU` 环境使用 **ONNX** 模型,需要卸载此包,然后安装包 **`onnxruntime-gpu`**
......@@ -24,5 +22,3 @@ pip install cnocr -i https://pypi.doubanio.com/simple
pip uninstall onnxruntime
pip install onnxruntime-gpu
```
......@@ -21,8 +21,6 @@ cnocr **V2.1** 目前包含以下可直接使用的模型,训练好的模型
| densenet\_lite\_134-gru | √ | X | 2.9 M | 11 M | 0.9738 | 17.042 |
| densenet\_lite\_136-gru | √ | X | 3.1 M | 12 M | 0.9756 | 17.725 |
一些说明:
1. 模型名称是由**局部编码**模型和**序列编码**模型名称拼接而成,以符号"-"分隔。
......
......@@ -55,8 +55,6 @@ Options:
> 注:需要尽量避免过度精调!
# 详细训练教程和训练过程作者答疑
[**模型训练详细教程**](https://articles.zsxq.com/id_u6b4u0wrf46e.html) 见作者的 **知识星球** [CnOCR/CnSTD私享群](https://t.zsxq.com/FEYZRJQ) ,加入私享群后作者也会尽力解答训练过程中遇到的问题。此外,私享群中作者每月提供两次免费特有数据的训练服务。**抱歉的是,私享群不是免费的。**
......@@ -20,7 +20,6 @@
import torch
from cnocr.consts import IMG_STANDARD_HEIGHT, ENG_LETTERS, VOCAB_FP
from cnocr.utils import read_charset, pad_img_seq, load_model_params, read_img, rescale_img, normalize_img_array
from cnocr.models.densenet import DenseNet
from cnocr.models.ocr_model import OcrModel
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册