未验证 提交 cb5807b1 编写于 作者: B BreezeDeus 提交者: GitHub

Merge pull request #160 from breezedeus/dev

v2.1: better models
# 可取值:['densenet-s']
ENCODER_NAME = densenet-s
# 可取值:['densenet_lite_136']
ENCODER_NAME = densenet_lite_136
# 可取值:['fc', 'gru', 'lstm']
DECODER_NAME = gru
DECODER_NAME = fc
MODEL_NAME = $(ENCODER_NAME)-$(DECODER_NAME)
EPOCH = 41
INDEX_DIR = data/test
TRAIN_CONFIG_FP = examples/train_config.json
TRAIN_CONFIG_FP = docs/examples/train_config.json
# 训练模型
train:
......@@ -14,19 +14,28 @@ train:
# 在测试集上评估模型,所有badcases的具体信息会存放到文件夹 `evaluate/$(MODEL_NAME)` 中
evaluate:
python scripts/cnocr_evaluate.py --model-name $(MODEL_NAME) --model-epoch 1 -v -i $(DATA_ROOT_DIR)/test.txt \
--image-prefix-dir examples --batch-size 128 -o evaluate/$(MODEL_NAME)
cnocr evaluate --model-name $(MODEL_NAME) -i data/test/dev.tsv \
--image-folder data/images --batch-size 128 -o eval_results/$(MODEL_NAME)
predict:
cnocr predict -m $(MODEL_NAME) -i examples/rand_cn1.png
cnocr predict -m $(MODEL_NAME) -i docs/examples/rand_cn1.png
doc:
# pip install mkdocs
# pip install mkdocs-macros-plugin
# pip install mkdocs-material
# pip install mkdocstrings
python -m mkdocs serve
# python -m mkdocs build
package:
python setup.py sdist bdist_wheel
VERSION = 2.0.1
VERSION = 2.1.0
upload:
python -m twine upload dist/cnocr-$(VERSION)* --verbose
.PHONY: train evaluate predict package upload
.PHONY: train evaluate predict doc package upload
此差异已折叠。
......@@ -17,4 +17,4 @@
# specific language governing permissions and limitations
# under the License.
__version__ = '2.0.1'
__version__ = '2.1.0'
......@@ -21,15 +21,19 @@ from __future__ import absolute_import, division, print_function
import os
import logging
import time
import click
from collections import Counter
import json
import glob
from operator import itemgetter
from pathlib import Path
import click
import Levenshtein
from torchvision import transforms as T
from cnocr.consts import MODEL_VERSION, ENCODER_CONFIGS, DECODER_CONFIGS
from cnocr.utils import set_logger, load_model_params, check_model_name
from cnocr.data_utils.aug import NormalizeAug, RandomPaddingAug
from cnocr.utils import set_logger, load_model_params, check_model_name, save_img, read_img
from cnocr.data_utils.aug import NormalizeAug, RandomPaddingAug, RandomStretchAug, RandomCrop
from cnocr.dataset import OcrDataModule
from cnocr.trainer import PlTrainer, resave_model
from cnocr import CnOcr, gen_model
......@@ -37,7 +41,7 @@ from cnocr import CnOcr, gen_model
_CONTEXT_SETTINGS = {"help_option_names": ['-h', '--help']}
logger = set_logger(log_level=logging.INFO)
DEFAULT_MODEL_NAME = 'densenet-s-fc'
DEFAULT_MODEL_NAME = 'densenet_lite_136-fc'
LEGAL_MODEL_NAMES = {
enc_name + '-' + dec_name
for enc_name in ENCODER_CONFIGS.keys()
......@@ -54,7 +58,7 @@ def cli():
@click.option(
'-m',
'--model-name',
type=click.Choice(LEGAL_MODEL_NAMES),
type=str,
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
)
......@@ -69,7 +73,7 @@ def cli():
'--train-config-fp',
type=str,
required=True,
help='训练使用的json配置文件,参考 `example/train_config.json`',
help='训练使用的json配置文件,参考 `docs/examples/train_config.json`',
)
@click.option(
'-r',
......@@ -92,8 +96,10 @@ def train(
check_model_name(model_name)
train_transform = T.Compose(
[
T.RandomInvert(p=0.5),
T.RandomRotation(degrees=2),
RandomStretchAug(min_ratio=0.5, max_ratio=1.5),
# RandomCrop((8, 10)),
T.RandomInvert(p=0.2),
T.RandomApply([T.RandomRotation(degrees=1)], p=0.4),
# T.RandomAutocontrast(p=0.05),
# T.RandomPosterize(bits=4, p=0.3),
# T.RandomAdjustSharpness(sharpness_factor=0.5, p=0.3),
......@@ -101,7 +107,6 @@ def train(
# T.RandomApply([T.GaussianBlur(kernel_size=3)], p=0.5),
NormalizeAug(),
# RandomPaddingAug(p=0.5, max_pad_len=72),
]
)
val_transform = NormalizeAug()
......@@ -119,6 +124,18 @@ def train(
pin_memory=train_config['pin_memory'],
)
# train_ds = data_mod.train
# for i in range(min(100, len(train_ds))):
# visualize_example(train_transform(train_ds[i][0]), 'debugs/train-1-%d' % i)
# visualize_example(train_transform(train_ds[i][0]), 'debugs/train-2-%d' % i)
# visualize_example(train_transform(train_ds[i][0]), 'debugs/train-3-%d' % i)
# val_ds = data_mod.val
# for i in range(min(10, len(val_ds))):
# visualize_example(val_transform(val_ds[i][0]), 'debugs/val-1-%d' % i)
# visualize_example(val_transform(val_ds[i][0]), 'debugs/val-2-%d' % i)
# visualize_example(val_transform(val_ds[i][0]), 'debugs/val-2-%d' % i)
# return
trainer = PlTrainer(
train_config, ckpt_fn=['cnocr', 'v%s' % MODEL_VERSION, model_name]
)
......@@ -133,20 +150,21 @@ def train(
)
def visualize_example(example, fp_prefix):
if not os.path.exists(os.path.dirname(fp_prefix)):
os.makedirs(os.path.dirname(fp_prefix))
image = example
save_img(image, '%s-image.jpg' % fp_prefix)
@cli.command('predict')
@click.option(
'-m',
'--model-name',
type=click.Choice(LEGAL_MODEL_NAMES),
type=str,
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
)
@click.option(
"--model_epoch",
type=int,
default=None,
help="model epoch。默认为 `None`,表示使用系统自带的预训练模型",
)
@click.option(
'-p',
'--pretrained-model-fp',
......@@ -155,6 +173,7 @@ def train(
help='使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型',
)
@click.option(
"-c",
"--context",
help="使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为 `cpu`",
type=str,
......@@ -167,15 +186,8 @@ def train(
is_flag=True,
help="是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后再进行识别",
)
def predict(
model_name, model_epoch, pretrained_model_fp, context, img_file_or_dir, single_line
):
ocr = CnOcr(
model_name=model_name,
model_epoch=model_epoch,
model_fp=pretrained_model_fp,
context=context,
)
def predict(model_name, pretrained_model_fp, context, img_file_or_dir, single_line):
ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context)
ocr_func = ocr.ocr_for_single_line if single_line else ocr.ocr
fp_list = []
if os.path.isfile(img_file_or_dir):
......@@ -197,6 +209,158 @@ def predict(
logger.info('\npred: %s, with probability %f' % (''.join(preds), prob))
@cli.command('evaluate')
@click.option(
'-m',
'--model-name',
type=str,
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
)
@click.option(
'-p',
'--pretrained-model-fp',
type=str,
default=None,
help='使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型',
)
@click.option(
"-c",
"--context",
help="使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为 `cpu`",
type=str,
default='cpu',
)
@click.option(
"-i",
"--eval-index-fp",
type=str,
help='待评估文件所在的索引文件,格式与训练时训练集索引文件相同,每行格式为 `<图片路径>\t<以空格分割的labels>`',
default='test.txt',
)
@click.option("--img-folder", required=True, help="图片所在文件夹,相对于索引文件中记录的图片位置")
@click.option("--batch-size", type=int, help="batch size. 默认值:128", default=128)
@click.option(
'-o',
'--output-dir',
type=str,
default='eval_results',
help='存放评估结果的文件夹。默认值:`eval_results`',
)
@click.option(
"-v", "--verbose", is_flag=True, help="whether to print details to screen",
)
def evaluate(
model_name,
pretrained_model_fp,
context,
eval_index_fp,
img_folder,
batch_size,
output_dir,
verbose,
):
ocr = CnOcr(model_name=model_name, model_fp=pretrained_model_fp, context=context)
fn_labels_list = read_input_file(eval_index_fp)
miss_cnt, redundant_cnt = Counter(), Counter()
total_time_cost = 0.0
bad_cnt = 0
badcases = []
start_idx = 0
while start_idx < len(fn_labels_list):
logger.info('start_idx: %d', start_idx)
batch = fn_labels_list[start_idx : start_idx + batch_size]
img_fps = [os.path.join(img_folder, fn) for fn, _ in batch]
reals = [labels for _, labels in batch]
imgs = [read_img(img) for img in img_fps]
start_time = time.time()
outs = ocr.ocr_for_single_lines(imgs, batch_size=1)
total_time_cost += time.time() - start_time
preds = [out[0] for out in outs]
for bad_info in compare_preds_to_reals(preds, reals, img_fps):
if verbose:
logger.info('\t'.join(bad_info))
distance = Levenshtein.distance(bad_info[1], bad_info[2])
bad_info.insert(0, distance)
badcases.append(bad_info)
miss_cnt.update(list(bad_info[-2]))
redundant_cnt.update(list(bad_info[-1]))
bad_cnt += 1
start_idx += batch_size
badcases.sort(key=itemgetter(0), reverse=True)
output_dir = Path(output_dir)
if not output_dir.exists():
os.makedirs(output_dir)
with open(output_dir / 'badcases.txt', 'w') as f:
f.write(
'\t'.join(
[
'distance',
'image_fp',
'real_words',
'pred_words',
'miss_words',
'redundant_words',
]
)
+ '\n'
)
for bad_info in badcases:
f.write('\t'.join(map(str, bad_info)) + '\n')
with open(output_dir / 'miss_words_stat.txt', 'w') as f:
for word, num in miss_cnt.most_common():
f.write('\t'.join([word, str(num)]) + '\n')
with open(output_dir / 'redundant_words_stat.txt', 'w') as f:
for word, num in redundant_cnt.most_common():
f.write('\t'.join([word, str(num)]) + '\n')
logger.info(
"number of total cases: %d, number of bad cases: %d, acc: %.4f, time cost per image: %f"
% (
len(fn_labels_list),
bad_cnt,
1.0 - bad_cnt / len(fn_labels_list),
total_time_cost / len(fn_labels_list),
)
)
def read_input_file(in_fp):
fn_labels_list = []
with open(in_fp) as f:
for line in f:
fields = line.strip().split('\t')
labels = fields[1].split(' ')
labels = [l if l != '<space>' else ' ' for l in labels]
fn_labels_list.append((fields[0], labels))
return fn_labels_list
def compare_preds_to_reals(batch_preds, batch_reals, batch_img_fns):
for preds, reals, img_fn in zip(batch_preds, batch_reals, batch_img_fns):
if preds == reals:
continue
preds_set, reals_set = set(preds), set(reals)
miss_words = reals_set.difference(preds_set)
redundant_words = preds_set.difference(reals_set)
yield [
img_fn,
''.join(reals),
''.join(preds),
''.join(miss_words),
''.join(redundant_words),
]
@cli.command('resave')
@click.option('-i', '--input-model-fp', type=str, required=True, help='输入的模型文件路径')
@click.option('-o', '--output-model-fp', type=str, required=True, help='输出的模型文件路径')
......
......@@ -48,11 +48,6 @@ logger = logging.getLogger(__name__)
def gen_model(model_name, vocab):
check_model_name(model_name)
if not model_name.startswith('densenet-s'):
logger.warning(
'only "densenet-s" is supported now, use "densenet-s-fc" by default'
)
model_name = 'densenet-s-fc'
model = OcrModel.from_name(model_name, vocab)
return model
......@@ -62,7 +57,7 @@ class CnOcr(object):
def __init__(
self,
model_name: str = 'densenet-s-fc',
model_name: str = 'densenet_lite_136-fc',
*,
cand_alphabet: Optional[Union[Collection, str]] = None,
context: str = 'cpu', # ['cpu', 'gpu', 'cuda']
......@@ -71,14 +66,28 @@ class CnOcr(object):
**kwargs,
):
"""
:param model_name: 模型名称。默认为 `densenet-s-fc`
:param cand_alphabet: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
:param context: 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为 `cpu`
:param model_fp: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件)
:param root: 模型文件所在的根目录。
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.0/densenet-s-fc`。
识别模型初始化函数。
Args:
model_name (str): 模型名称。默认为 `densenet_lite_136-fc`
cand_alphabet (Optional[Union[Collection, str]]): 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
context (str): 'cpu', or 'gpu'。表明预测时是使用CPU还是GPU。默认为 `cpu`
model_fp (Optional[str]): 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.ckpt' 文件)
root (Union[str, Path]): 模型文件所在的根目录。
Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`。
Windows下默认值为 `C:/Users/<username>/AppData/Roaming/cnocr`。
**kwargs: 目前未被使用。
Examples:
使用默认参数:
>>> ocr = CnOcr()
使用指定模型:
>>> ocr = CnOcr(model_name='densenet_lite_136-fc')
识别时只考虑数字:
>>> ocr = CnOcr(model_name='densenet_lite_136-fc', cand_alphabet='0123456789')
"""
if 'name' in kwargs:
logger.warning(
......@@ -144,8 +153,13 @@ class CnOcr(object):
def set_cand_alphabet(self, cand_alphabet: Optional[Union[Collection, str]]):
"""
设置待识别字符的候选集合。
:param cand_alphabet: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
:return: None
Args:
cand_alphabet (Optional[Union[Collection, str]]): 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围
Returns:
None
"""
if cand_alphabet is None:
self._candidates = None
......@@ -169,10 +183,15 @@ class CnOcr(object):
self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]
) -> List[Tuple[List[str], float]]:
"""
:param img_fp: image file path; or color image torch.Tensor or np.ndarray,
识别函数。
Args:
img_fp (Union[str, Path, torch.Tensor, np.ndarray]): image file path; or color image torch.Tensor or np.ndarray,
with shape [height, width] or [height, width, channel].
channel should be 1 (gray image) or 3 (RGB formatted color image). scaled in [0, 255].
:return: list of (list of chars, prob), such as
Returns:
list of (list of chars, prob), such as
[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]
"""
img = self._prepare_img(img_fp)
......@@ -190,11 +209,15 @@ class CnOcr(object):
self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]
) -> np.ndarray:
"""
:param img: image array with type torch.Tensor or np.ndarray,
Args:
img_fp (Union[str, Path, torch.Tensor, np.ndarray]):
image array with type torch.Tensor or np.ndarray,
with shape [height, width] or [height, width, channel].
channel should be 1 (gray image) or 3 (color image).
:return: np.ndarray, with shape (height, width, 1), dtype uint8, scale [0, 255]
Returns:
np.ndarray: with shape (height, width, 1), dtype uint8, scale [0, 255]
"""
img = img_fp
if isinstance(img_fp, (str, Path)):
......@@ -226,10 +249,15 @@ class CnOcr(object):
) -> Tuple[List[str], float]:
"""
Recognize characters from an image with only one-line characters.
:param img_fp: image file path; or image torch.Tensor or np.ndarray,
Args:
img_fp (Union[str, Path, torch.Tensor, np.ndarray]):
image file path; or image torch.Tensor or np.ndarray,
with shape [height, width] or [height, width, channel].
The optional channel should be 1 (gray image) or 3 (color image).
:return: (list of chars, prob), such as (['你', '好'], 0.80)
Returns:
tuple: (list of chars, prob), such as (['你', '好'], 0.80)
"""
img = self._prepare_img(img_fp)
res = self.ocr_for_single_lines([img])
......@@ -242,27 +270,49 @@ class CnOcr(object):
) -> List[Tuple[List[str], float]]:
"""
Batch recognize characters from a list of one-line-characters images.
:param img_list: list of images, in which each element should be a line image array,
Args:
img_list (List[Union[str, Path, torch.Tensor, np.ndarray]]):
list of images, in which each element should be a line image array,
with type torch.Tensor or np.ndarray.
Each element should be a tensor with values ranging from 0 to 255,
and with shape [height, width] or [height, width, channel].
The optional channel should be 1 (gray image) or 3 (color image).
:param batch_size: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。
:return: list of (list of chars, prob), such as
注:img_list 不宜包含太多图片,否则同时导入这些图片会消耗很多内存。
batch_size: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。
Returns:
list: list of (list of chars, prob), such as
[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]
"""
if len(img_list) == 0:
return []
img_list = [self._prepare_img(img) for img in img_list]
img_list = [self._transform_img(img) for img in img_list]
should_sort = batch_size > 1 and len(img_list) // batch_size > 1
if should_sort:
# 把图片按宽度从小到大排列,提升效率
sorted_idx_list = sorted(
range(len(img_list)), key=lambda i: img_list[i].shape[2]
)
sorted_img_list = [img_list[i] for i in sorted_idx_list]
else:
sorted_idx_list = range(len(img_list))
sorted_img_list = img_list
idx = 0
out = []
while idx * batch_size < len(img_list):
imgs = img_list[idx * batch_size : (idx + 1) * batch_size]
sorted_out = []
while idx * batch_size < len(sorted_img_list):
imgs = sorted_img_list[idx * batch_size : (idx + 1) * batch_size]
batch_out = self._predict(imgs)
out.extend(batch_out['preds'])
sorted_out.extend(batch_out['preds'])
idx += 1
out = [None] * len(sorted_out)
for idx, pred in zip(sorted_idx_list, sorted_out):
out[idx] = pred
res = []
for line in out:
......@@ -274,11 +324,13 @@ class CnOcr(object):
def _transform_img(self, img: np.ndarray) -> torch.Tensor:
"""
:param img: image array with type torch.Tensor or np.ndarray,
Args:
img: image array with type torch.Tensor or np.ndarray,
with shape [height, width] or [height, width, channel].
channel shoule be 1 (gray image) or 3 (color image).
:return: torch.Tensor, with shape (1, height, width)
Returns:
torch.Tensor: with shape (1, height, width)
"""
img = rescale_img(img.transpose((2, 0, 1))) # res: [C, H, W]
return NormalizeAug()(img).to(device=torch.device(self.context))
......
......@@ -30,38 +30,89 @@ IMG_STANDARD_HEIGHT = 32
VOCAB_FP = Path(__file__).parent / 'label_cn.txt'
ENCODER_CONFIGS = {
'densenet-s': { # 长度压缩至 1/8(seq_len == 35),输出的向量长度为 4*128 = 512
'densenet': { # 长度压缩至 1/8(seq_len == 35),输出的向量长度为 4*128 = 512
'growth_rate': 32,
'block_config': [2, 2, 2, 2],
'num_init_features': 64,
'out_length': 512, # 输出的向量长度为 4*128 = 512
},
'densenet_1112': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 1, 2],
'num_init_features': 64,
'out_length': 400,
},
'densenet_1114': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 1, 4],
'num_init_features': 64,
'out_length': 656,
},
'densenet_1122': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 2, 2],
'num_init_features': 64,
'out_length': 464,
},
'densenet_1124': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 2, 4],
'num_init_features': 64,
'out_length': 720,
},
'densenet_lite_113': { # 长度压缩至 1/8(seq_len == 35),输出的向量长度为 2*136 = 272
'growth_rate': 32,
'block_config': [1, 1, 3],
'num_init_features': 64,
'out_length': 272, # 输出的向量长度为 2*80 = 160
},
'densenet_lite_114': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 1, 4],
'num_init_features': 64,
'out_length': 336,
},
'densenet_lite_124': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 2, 4],
'num_init_features': 64,
'out_length': 368,
},
'densenet_lite_134': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 3, 4],
'num_init_features': 64,
'out_length': 400,
},
'densenet_lite_136': { # 长度压缩至 1/8(seq_len == 35)
'growth_rate': 32,
'block_config': [1, 3, 6],
'num_init_features': 64,
'out_length': 528,
},
'mobilenetv3_tiny': {'arch': 'tiny', 'out_length': 384,},
'mobilenetv3_small': {'arch': 'small', 'out_length': 384,},
}
DECODER_CONFIGS = {
'lstm': {
'input_size': 512, # 对应 encoder 的输出向量长度
'rnn_units': 128,
},
'gru': {
'input_size': 512, # 对应 encoder 的输出向量长度
'rnn_units': 128,
},
'fc': {
'input_size': 512, # 对应 encoder 的输出向量长度
'hidden_size': 256,
'dropout': 0.3,
}
'lstm': {'rnn_units': 128,},
'gru': {'rnn_units': 128,},
'fc': {'hidden_size': 128, 'dropout': 0.1,},
'fcfull': {'hidden_size': 256, 'dropout': 0.3,},
}
root_url = (
'https://beiye-model.oss-cn-beijing.aliyuncs.com/models/cnocr/%s/'
'https://huggingface.co/breezedeus/cnstd-cnocr-models/resolve/main/models/cnocr/%s/'
% MODEL_VERSION
)
# name: (epochs, url)
# name: (epoch, url)
AVAILABLE_MODELS = {
'densenet-s-fc': (8, root_url + 'densenet-s-fc-v2.0.1.zip'),
'densenet-s-gru': (14, root_url + 'densenet-s-gru-v2.0.1.zip'),
'densenet_lite_114-fc': (37, root_url + 'densenet_lite_114-fc.zip'),
'densenet_lite_124-fc': (39, root_url + 'densenet_lite_124-fc.zip'),
'densenet_lite_134-fc': (34, root_url + 'densenet_lite_134-fc.zip'),
'densenet_lite_136-fc': (39, root_url + 'densenet_lite_136-fc.zip'),
'densenet_lite_134-gru': (2, root_url + 'densenet_lite_134-gru.zip'),
'densenet_lite_136-gru': (2, root_url + 'densenet_lite_136-gru.zip'),
}
# 候选字符集合
......
......@@ -18,8 +18,10 @@
# under the License.
import random
from typing import Tuple
import torch
import torchvision.transforms.functional as F
from ..utils import normalize_img_array
......@@ -32,6 +34,7 @@ class FgBgFlipAug(object):
p : float
Probability to flip image horizontally
"""
def __init__(self, p):
self.p = p
......@@ -47,6 +50,71 @@ class NormalizeAug(object):
return normalize_img_array(img)
class RandomStretchAug(object):
"""对图片在宽度上做随机拉伸"""
def __init__(self, min_ratio=0.9, max_ratio=1.1):
self.min_ratio = min_ratio
self.max_ratio = max_ratio
def __call__(self, img: torch.Tensor):
"""
:param img: [C, H, W]
:return:
"""
_, h, w = img.shape
new_w_ratio = self.min_ratio + random.random() * (
self.max_ratio - self.min_ratio
)
return F.resize(img, [h, int(w * new_w_ratio)])
class RandomCrop(torch.nn.Module):
def __init__(
self, crop_size: Tuple[int, int], interpolation=F.InterpolationMode.BILINEAR
):
super().__init__()
self.crop_size = crop_size
self.interpolation = interpolation
def get_params(self, ori_w, ori_h) -> Tuple[int, int, int, int]:
"""Get parameters for ``crop`` for a random crop.
Args:
Returns:
tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
"""
while True:
h_top, h_bot = (
random.randint(0, self.crop_size[0]),
random.randint(0, self.crop_size[0]),
)
w_left, w_right = (
random.randint(0, self.crop_size[1]),
random.randint(0, self.crop_size[1]),
)
h = ori_h - h_top - h_bot
w = ori_w - w_left - w_right
if h < ori_h * 0.5 or w < ori_w * 0.9:
continue
return h_top, w_left, h, w
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be cropped and resized.
Returns:
PIL Image or Tensor: Randomly cropped and resized image.
"""
ori_w, ori_h = F._get_image_size(img)
i, j, h, w = self.get_params(ori_w, ori_h)
return F.resized_crop(img, i, j, h, w, (ori_h, ori_w), self.interpolation)
class RandomPaddingAug(object):
def __init__(self, p, max_pad_len):
self.p = p
......
# coding: utf-8
# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Credits: adapted from https://mp.weixin.qq.com/s/xGvaW87UQFjetc5xFmKxWg
import random
from torch.utils.data import Dataset, DataLoader
class BlockShuffleDataLoader(DataLoader):
def __init__(
self, dataset: Dataset, **kwargs
):
"""
对 OcrDataset 数据集实现Block Shuffle功能,按文字数量从少到多的顺序排列样本(相同长度样本则随机排列)
Args:
dataset: OcrDataset类的实例,其中中必须包含labels_list变量,并且该变量为一个list
**kwargs:
"""
assert isinstance(
dataset.labels_list, list
), "dataset为OcrDataset类的实例,其中必须包含labels_list变量,并且该变量为一个list"
kwargs['shuffle'] = False
super().__init__(dataset, **kwargs)
def __iter__(self):
self.block_shuffle2()
return super().__iter__()
def block_shuffle2(self):
idx_list = list(range(len(self.dataset)))
random.shuffle(idx_list)
random.shuffle(idx_list)
idx_list.sort(key=lambda idx: len(self.dataset.labels_list[idx]))
for attr in ('img_fp_list', 'labels_list'):
ori_list = getattr(self.dataset, attr)
new_list = [ori_list[idx] for idx in idx_list]
setattr(self.dataset, attr, new_list)
......@@ -59,9 +59,9 @@ def collate_fn(img_labels: List[Tuple[str, str]], transformers: Callable = None)
img_list, labels_list = zip(*img_labels)
label_lengths = torch.tensor([len(labels) for labels in labels_list])
img_lengths = torch.tensor([img.size(2) for img in img_list])
if transformers is not None:
img_list = [transformers(img) for img in img_list]
img_lengths = torch.tensor([img.size(2) for img in img_list])
imgs = pad_img_seq(img_list)
return imgs, img_lengths, labels_list, label_lengths
......
# coding: utf-8
# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
import math
import torch
from torch.optim.lr_scheduler import (
_LRScheduler,
StepLR,
LambdaLR,
CyclicLR,
CosineAnnealingWarmRestarts,
MultiStepLR,
OneCycleLR,
)
def get_lr_scheduler(config, optimizer):
orig_lr = config['learning_rate']
lr_sch_config = deepcopy(config['lr_scheduler'])
lr_sch_name = lr_sch_config.pop('name')
epochs = config['epochs']
steps_per_epoch = config['steps_per_epoch']
if lr_sch_name == 'multi_step':
milestones = [v * steps_per_epoch for v in lr_sch_config['milestones']]
return MultiStepLR(
optimizer, milestones=milestones, gamma=lr_sch_config['gamma'],
)
elif lr_sch_name == 'cos_warmup':
min_lr_mult_factor = lr_sch_config.get('min_lr_mult_factor', 0.1)
warmup_epochs = lr_sch_config.get('warmup_epochs', 0.1)
return WarmupCosineAnnealingRestarts(
optimizer,
first_cycle_steps=steps_per_epoch * epochs,
max_lr=orig_lr,
min_lr=orig_lr * min_lr_mult_factor,
warmup_steps=int(steps_per_epoch * warmup_epochs),
)
elif lr_sch_name == 'cos_anneal':
# 5 个 epochs, 一个循环
return CosineAnnealingWarmRestarts(
optimizer, T_0=5 * steps_per_epoch, T_mult=1, eta_min=orig_lr * 0.1
)
elif lr_sch_name == 'cyclic':
return CyclicLR(
optimizer,
base_lr=orig_lr / 10.0,
max_lr=orig_lr,
step_size_up=5 * steps_per_epoch, # 5 个 epochs, 从最小base_lr上升到最大max_lr
cycle_momentum=False,
)
elif lr_sch_name == 'one_cycle':
return OneCycleLR(
optimizer, max_lr=orig_lr, epochs=epochs, steps_per_epoch=steps_per_epoch,
)
step_size = lr_sch_config['step_size']
gamma = lr_sch_config['gamma']
if step_size is None or gamma is None:
return LambdaLR(optimizer, lr_lambda=lambda _: 1)
return StepLR(optimizer, step_size, gamma=gamma)
class WarmupCosineAnnealingRestarts(_LRScheduler):
"""
from https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup/blob/master/cosine_annealing_warmup/scheduler.py
optimizer (Optimizer): Wrapped optimizer.
first_cycle_steps (int): First cycle step size.
cycle_mult(float): Cycle steps magnification. Default: -1.
max_lr(float): First cycle's max learning rate. Default: 0.1.
min_lr(float): Min learning rate. Default: 0.001.
warmup_steps(int): Linear warmup step size. Default: 0.
gamma(float): Decrease rate of max learning rate by cycle. Default: 1.
last_epoch (int): The index of last epoch. Default: -1.
"""
def __init__(
self,
optimizer: torch.optim.Optimizer,
first_cycle_steps: int,
cycle_mult: float = 1.0,
max_lr: float = 0.1,
min_lr: float = 0.001,
warmup_steps: int = 0,
gamma: float = 1.0,
last_epoch: int = -1,
):
assert warmup_steps < first_cycle_steps
self.first_cycle_steps = first_cycle_steps # first cycle step size
self.cycle_mult = cycle_mult # cycle steps magnification
self.base_max_lr = max_lr # first max learning rate
self.max_lr = max_lr # max learning rate in the current cycle
self.min_lr = min_lr # min learning rate
self.warmup_steps = warmup_steps # warmup step size
self.gamma = gamma # decrease rate of max learning rate by cycle
self.cur_cycle_steps = first_cycle_steps # first cycle step size
self.cycle = 0 # cycle count
self.step_in_cycle = last_epoch # step size of the current cycle
super().__init__(optimizer, last_epoch)
# set learning rate min_lr
self.init_lr()
def init_lr(self):
self.base_lrs = []
for param_group in self.optimizer.param_groups:
param_group['lr'] = self.min_lr
self.base_lrs.append(self.min_lr)
def get_lr(self):
if self.step_in_cycle == -1:
return self.base_lrs
elif self.step_in_cycle < self.warmup_steps:
return [
(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps
+ base_lr
for base_lr in self.base_lrs
]
else:
return [
base_lr
+ (self.max_lr - base_lr)
* (
1
+ math.cos(
math.pi
* (self.step_in_cycle - self.warmup_steps)
/ (self.cur_cycle_steps - self.warmup_steps)
)
)
/ 2
for base_lr in self.base_lrs
]
def step(self, epoch=None):
if epoch is None:
epoch = self.last_epoch + 1
self.step_in_cycle = self.step_in_cycle + 1
if self.step_in_cycle >= self.cur_cycle_steps:
self.cycle += 1
self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
self.cur_cycle_steps = (
int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult)
+ self.warmup_steps
)
else:
if epoch >= self.first_cycle_steps:
if self.cycle_mult == 1.0:
self.step_in_cycle = epoch % self.first_cycle_steps
self.cycle = epoch // self.first_cycle_steps
else:
n = int(
math.log(
(
epoch / self.first_cycle_steps * (self.cycle_mult - 1)
+ 1
),
self.cycle_mult,
)
)
self.cycle = n
self.step_in_cycle = epoch - int(
self.first_cycle_steps
* (self.cycle_mult ** n - 1)
/ (self.cycle_mult - 1)
)
self.cur_cycle_steps = self.first_cycle_steps * self.cycle_mult ** (
n
)
else:
self.cur_cycle_steps = self.first_cycle_steps
self.step_in_cycle = epoch
self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
self.last_epoch = math.floor(epoch)
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
param_group['lr'] = lr
......@@ -29,7 +29,7 @@ class DenseNet(densenet.DenseNet):
def __init__(
self,
growth_rate: int = 32,
block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
block_config: Tuple[int, int, int, int] = (2, 2, 2, 2),
num_init_features: int = 64,
bn_size: int = 4,
drop_rate: float = 0,
......@@ -46,13 +46,36 @@ class DenseNet(densenet.DenseNet):
)
self.block_config = block_config
delattr(self, 'classifier')
self.features.conv0 = nn.Conv2d(
1, num_init_features, kernel_size=3, stride=1, padding=1, bias=False
)
self.features.pool0 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
delattr(self, 'classifier')
last_denselayer = self._get_last_denselayer(len(self.block_config))
conv = last_denselayer.conv2
in_channels, out_channels = conv.in_channels, conv.out_channels
last_denselayer.conv2 = nn.Conv2d(
in_channels, out_channels, kernel_size=5, stride=1, padding=2, bias=False
)
# for i in range(1, len(self.block_config)):
# transition = getattr(self.features, 'transition%d' % i)
# in_channels, out_channels = transition.conv.in_channels, transition.conv.out_channels
# trans = _MaxPoolTransition(num_input_features=in_channels,
# num_output_features=out_channels)
# setattr(self.features, 'transition%d' % i, trans)
self._post_init_weights()
def _get_last_denselayer(self, block_num):
denseblock = getattr(self.features, 'denseblock%d' % block_num)
i = 1
while hasattr(denseblock, 'denselayer%d' % i):
i += 1
return getattr(denseblock, 'denselayer%d' % (i-1))
@property
def compress_ratio(self):
return 2 ** (len(self.block_config) - 1)
......@@ -71,3 +94,51 @@ class DenseNet(densenet.DenseNet):
def forward(self, x: Tensor) -> Tensor:
features = self.features(x)
return features
class DenseNetLite(DenseNet):
def __init__(
self,
growth_rate: int = 32,
block_config: Tuple[int, int, int] = (2, 2, 2),
num_init_features: int = 64,
bn_size: int = 4,
drop_rate: float = 0,
memory_efficient: bool = False,
) -> None:
super().__init__(
growth_rate,
block_config,
num_init_features,
bn_size,
drop_rate,
memory_efficient=memory_efficient,
)
self.features.pool0 = nn.AvgPool2d(kernel_size=2, stride=2)
# last max pool, pool 1/8 to 1/16 for height dimension
self.features.add_module(
'pool5', nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1))
)
@property
def compress_ratio(self):
return 2 ** len(self.block_config)
class _MaxPoolTransition(nn.Sequential):
def __init__(self, num_input_features: int, num_output_features: int) -> None:
super().__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True))
self.add_module(
'conv',
nn.Conv2d(
num_input_features,
num_output_features,
kernel_size=1,
stride=1,
bias=False,
),
)
self.add_module('pool', nn.MaxPool2d(kernel_size=2, stride=2))
# coding: utf-8
# Copyright (C) 2021, [Breezedeus](https://github.com/breezedeus).
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# adapted from: torchvision/models/mobilenetv3.py
from functools import partial
from typing import Any, List, Optional, Callable
from torch import nn, Tensor
from torchvision.models.mobilenetv2 import ConvBNActivation
from torchvision.models import mobilenetv3
from torchvision.models.mobilenetv3 import InvertedResidualConfig
class MobileNetV3(mobilenetv3.MobileNetV3):
def __init__(
self,
inverted_residual_setting: List[InvertedResidualConfig],
block: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
**kwargs: Any
) -> None:
super().__init__(inverted_residual_setting, 1, 2, block, norm_layer)
delattr(self, 'classifier')
firstconv_input_channels = self.features[0][0].out_channels
self.features[0] = ConvBNActivation(
1,
firstconv_input_channels,
kernel_size=3,
stride=2,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
lastconv_input_channels = self.features[-1][0].in_channels
lastconv_output_channels = 2 * lastconv_input_channels
self.features[-1] = ConvBNActivation(
lastconv_input_channels,
lastconv_output_channels,
kernel_size=1,
norm_layer=norm_layer,
activation_layer=nn.Hardswish,
)
self.avgpool = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1))
self._post_init_weights()
@property
def compress_ratio(self):
return 8
def _post_init_weights(self):
# Official init from torch repo.
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x: Tensor) -> Tensor:
features = self.features(x)
features = self.avgpool(features)
return features
def _mobilenet_v3_conf(
arch: str,
width_mult: float = 1.0,
reduced_tail: bool = False,
dilated: bool = False,
**kwargs: Any
):
reduce_divider = 2 if reduced_tail else 1
dilation = 2 if dilated else 1
bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
adjust_channels = partial(
InvertedResidualConfig.adjust_channels, width_mult=width_mult
)
if arch == "mobilenet_v3_tiny":
inverted_residual_setting = [
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, False, "HS", 2, 1), # C3
bneck_conf(40, 5, 120, 48, False, "HS", 1, 1),
# bneck_conf(48, 5, 144, 48, False, "HS", 1, 1),
bneck_conf(
48, 5, 288, 96 // reduce_divider, False, "HS", 2, dilation
), # C4
bneck_conf(
96 // reduce_divider,
5,
128 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
bneck_conf(
96 // reduce_divider,
5,
128 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
]
elif arch == "mobilenet_v3_small":
inverted_residual_setting = [
bneck_conf(16, 3, 16, 16, False, "RE", 1, 1), # C1
bneck_conf(16, 3, 72, 24, False, "RE", 1, 1), # C2
bneck_conf(24, 3, 88, 24, False, "RE", 1, 1),
bneck_conf(24, 5, 96, 40, False, "HS", 2, 1), # C3
bneck_conf(40, 5, 240, 40, False, "HS", 1, 1),
bneck_conf(40, 5, 240, 40, False, "HS", 1, 1),
bneck_conf(40, 5, 120, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 144, 48, True, "HS", 1, 1),
bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2, dilation), # C4
bneck_conf(
96 // reduce_divider,
5,
576 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
bneck_conf(
96 // reduce_divider,
5,
576 // reduce_divider,
96 // reduce_divider,
True,
"HS",
1,
dilation,
),
]
else:
raise ValueError("Unsupported model type {}".format(arch))
return inverted_residual_setting
def _mobilenet_v3_model(
inverted_residual_setting: List[InvertedResidualConfig], **kwargs: Any
):
model = MobileNetV3(inverted_residual_setting, **kwargs)
return model
def gen_mobilenet_v3(arch: str = 'tiny', **kwargs: Any) -> MobileNetV3:
"""
Constructs a small MobileNetV3 architecture from
`"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
Args:
arch (str): arch name; values: 'tiny' or 'small'
"""
arch = 'mobilenet_v3_%s' % arch
inverted_residual_setting = _mobilenet_v3_conf(arch, **kwargs)
return _mobilenet_v3_model(inverted_residual_setting, **kwargs)
......@@ -30,7 +30,8 @@ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from .ctc import CTCPostProcessor
from ..consts import ENCODER_CONFIGS, DECODER_CONFIGS
from ..data_utils.utils import encode_sequences
from .densenet import DenseNet
from .densenet import DenseNet, DenseNetLite
from .mobilenet import gen_mobilenet_v3
class EncoderManager(object):
......@@ -45,9 +46,16 @@ class EncoderManager(object):
assert config is not None and 'name' in config
name = config.pop('name')
if name.lower() == 'densenet-s':
if name.lower().startswith('densenet_lite'):
out_length = config.pop('out_length')
encoder = DenseNetLite(**config)
elif name.lower().startswith('densenet'):
out_length = config.pop('out_length')
encoder = DenseNet(**config)
elif name.lower().startswith('mobilenet'):
arch = config['arch']
out_length = config.pop('out_length')
encoder = gen_mobilenet_v3(arch)
else:
raise ValueError('not supported encoder name: %s' % name)
return encoder, out_length
......@@ -86,11 +94,11 @@ class DecoderManager(object):
bidirectional=True,
)
out_length = config['rnn_units'] * 2
elif name.lower() == 'fc':
elif name.lower() in ('fc', 'fcfull'):
decoder = nn.Sequential(
nn.Dropout(p=config['dropout']),
# nn.Tanh(),
nn.Linear(config['input_size'], config['hidden_size']),
nn.Linear(input_size, config['hidden_size']),
nn.Dropout(p=config['dropout']),
nn.Tanh(),
)
......
......@@ -26,16 +26,12 @@ import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from torch.optim.lr_scheduler import (
StepLR,
LambdaLR,
CyclicLR,
CosineAnnealingWarmRestarts,
MultiStepLR,
)
from torch.utils.data import DataLoader
from .lr_scheduler import get_lr_scheduler
logger = logging.getLogger(__name__)
......@@ -67,37 +63,6 @@ def get_optimizer(name: str, model, learning_rate, weight_decay):
return optimizer
def get_lr_scheduler(config, optimizer):
orig_lr = config['learning_rate']
lr_sch_config = deepcopy(config['lr_scheduler'])
lr_sch_name = lr_sch_config.pop('name')
if lr_sch_name == 'multi_step':
return MultiStepLR(
optimizer,
milestones=lr_sch_config['milestones'],
gamma=lr_sch_config['gamma'],
)
elif lr_sch_name == 'cos_anneal':
return CosineAnnealingWarmRestarts(
optimizer, T_0=4, T_mult=1, eta_min=orig_lr / 10.0
)
elif lr_sch_name == 'cyclic':
return CyclicLR(
optimizer,
base_lr=orig_lr / 10.0,
max_lr=orig_lr,
step_size_up=2,
cycle_momentum=False,
)
step_size = lr_sch_config['step_size']
gamma = lr_sch_config['gamma']
if step_size is None or gamma is None:
return LambdaLR(optimizer, lr_lambda=lambda _: 1)
return StepLR(optimizer, step_size, gamma=gamma)
class Accuracy(object):
@classmethod
def complete_match(cls, labels: List[List[str]], preds: List[List[str]]):
......@@ -144,6 +109,11 @@ class WrapperLightningModule(pl.LightningModule):
else:
setattr(self.model, 'current_epoch', self.current_epoch)
res = self.model.calculate_loss(batch)
# update lr scheduler
sch = self.lr_schedulers()
sch.step()
losses = res['loss']
self.log(
'train_loss',
......@@ -214,7 +184,7 @@ class PlTrainer(object):
max_epochs=self.config.get('epochs', 20),
precision=self.config.get('precision', 32),
callbacks=callbacks,
stochastic_weight_avg=True,
stochastic_weight_avg=False,
)
def fit(
......@@ -241,6 +211,12 @@ class PlTrainer(object):
no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint,
training will start from the beginning of the next epoch.
"""
steps_per_epoch = (
len(train_dataloader)
if train_dataloader is not None
else len(datamodule.train_dataloader())
)
self.config['steps_per_epoch'] = steps_per_epoch
if resume_from_checkpoint is not None:
pl_module = WrapperLightningModule.load_from_checkpoint(
resume_from_checkpoint, config=self.config, model=model
......
......@@ -102,7 +102,7 @@ def data_dir():
def check_model_name(model_name):
encoder_type, decoder_type = model_name.rsplit('-', maxsplit=1)
encoder_type, decoder_type = model_name.split('-')[:2]
assert encoder_type in ENCODER_CONFIGS
assert decoder_type in DECODER_CONFIGS
......@@ -362,3 +362,9 @@ def load_model_params(model, param_fp, device='cpu'):
state_dict[k.split('.', maxsplit=1)[1]] = v
model.load_state_dict(state_dict)
return model
def get_model_size(model, only_trainable=False):
if only_trainable:
return sum(p.numel() for p in model.parameters() if p.requires_grad)
return sum(p.numel() for p in model.parameters())
# Release Notes
### Update 2021.11.06: 发布 cnocr V2.1.0
主要变更:
* 使用了更精简的模型架构:`densenet_lite_*`
* 使用了更丰富的数据重新训练了所有模型,精度相较于之前版本更高;
* 提供了更多预训练好的模型;
* 加入了 `cnocr evaluate` 命令以评估效果。
### Update 2021.09.21: 发布 cnocr V2.0.1
主要变更:
......
# 强强联合:[CnStd](https://github.com/breezedeus/cnstd) + CnOcr
关于为什么要结合 [CnStd](https://github.com/breezedeus/cnstd) 和 CnOcr 一起使用,可参考 [场景文字识别介绍](std_ocr.md)
对于一般的场景图片(如照片、票据等),需要先利用场景文字检测引擎 **[cnstd](https://github.com/breezedeus/cnstd)** 定位到文字所在位置,然后再利用 **cnocr** 进行文本识别。
```python
from cnstd import CnStd
from cnocr import CnOcr
std = CnStd()
cn_ocr = CnOcr()
box_infos = std.detect('examples/taobao.jpg')
for box_info in box_infos['detected_texts']:
cropped_img = box_info['cropped_img']
ocr_res = cn_ocr.ocr_for_single_line(cropped_img)
print('ocr result: %s' % str(ocr_out))
```
注:运行上面示例需要先安装 **[cnstd](https://github.com/breezedeus/cnstd)**
```bash
pip install cnstd
```
**[cnstd](https://github.com/breezedeus/cnstd)** 相关的更多使用说明请参考其项目地址。
可基于 [在线 Demo](demo.md) 查看 CnStd + CnOcr 的联合效果。
# 脚本使用
**cnocr** 包含了几个命令行工具,安装 **cnocr** 后即可使用。
## 预测单个文件或文件夹中所有图片
使用命令 **`cnocr predict`** 预测单个文件或文件夹中所有图片,以下是使用说明:
```bash
(venv) ➜ cnocr git:(dev) ✗ cnocr predict -h
Usage: cnocr predict [OPTIONS]
Options:
-m, --model-name TEXT 模型名称。默认值为 densenet_lite_136-fc
-p, --pretrained-model-fp TEXT 使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型
-c, --context TEXT 使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为
`cpu`
-i, --img-file-or-dir TEXT 输入图片的文件路径或者指定的文件夹 [required]
-s, --single-line 是否输入图片只包含单行文字。对包含单行文字的图片,不做按行切分;否则会先对图片按行分割后
再进行识别
-h, --help Show this message and exit.
```
例如可以使用以下命令对图片 `docs/examples/rand_cn1.png` 进行文字识别:
```bash
cnstd predict -i docs/examples/rand_cn1.png -s
```
具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile)
## 模型评估
使用命令 **`cnocr evaluate`** 在指定的数据集上评估模型效果,以下是使用说明:
```bash
(venv) ➜ cnocr git:(dev) ✗ cnocr evaluate -h
Usage: cnocr evaluate [OPTIONS]
Options:
-m, --model-name TEXT 模型名称。默认值为 densenet_lite_136-fc
-p, --pretrained-model-fp TEXT 使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型
-c, --context TEXT 使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为
`cpu`
-i, --eval-index-fp TEXT 待评估文件所在的索引文件,格式与训练时训练集索引文件相同,每行格式为 `<图片路径>
<以空格分割的labels>`
--img-folder TEXT 图片所在文件夹,相对于索引文件中记录的图片位置 [required]
--batch-size INTEGER batch size. 默认值:`128`
-o, --output-dir TEXT 存放评估结果的文件夹。默认值:`eval_results`
-v, --verbose whether to print details to screen
-h, --help Show this message and exit.
```
例如可以使用以下命令评估 `data/test/dev.tsv` 中指定的所有样本:
```bash
cnocr evaluate -i data/test/dev.tsv --image-folder data/images
```
具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile)
## 模型训练
使用命令 **`cnocr train`** 训练文本检测模型,以下是使用说明:
```bash
(venv) ➜ cnocr git:(dev) ✗ cnocr train -h
Usage: cnocr train [OPTIONS]
Options:
-m, --model-name TEXT 模型名称。默认值为 densenet_lite_136-fc
-i, --index-dir TEXT 索引文件所在的文件夹,会读取文件夹中的 train.tsv 和 dev.tsv 文件
[required]
--train-config-fp TEXT 训练使用的json配置文件,参考
`docs/examples/train_config.json`
[required]
-r, --resume-from-checkpoint TEXT
恢复此前中断的训练状态,继续训练。默认为 `None`
-p, --pretrained-model-fp TEXT 导入的训练好的模型,作为初始模型。优先级低于"--restore-training-
fp",当传入"--restore-training-fp"时,此传入失效。默认为
`None`
-h, --help Show this message and exit.
```
例如可以使用以下命令进行训练:
```bash
cnocr train -m densenet_lite_136-fc --index-dir data/test --train-config-fp docs/examples/train_config.json
```
训练数据的格式见文件夹 [data/test](https://github.com/breezedeus/cnocr/blob/master/data/test) 中的 [train.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/train.tsv)[dev.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/dev.tsv) 文件。
具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile)
## 模型转存
训练好的模型会存储训练状态,使用命令 **`cnocr resave`** 去掉与预测无关的数据,降低模型大小。
```bash
(venv) ➜ cnocr git:(pytorch) ✗ cnocr resave -h
Usage: cnocr resave [OPTIONS]
训练好的模型会存储训练状态,使用此命令去掉预测时无关的数据,降低模型大小
Options:
-i, --input-model-fp TEXT 输入的模型文件路径 [required]
-o, --output-model-fp TEXT 输出的模型文件路径 [required]
-h, --help Show this message and exit.
```
# QQ 交流群
欢迎扫码加入QQ交流群:
![QQ群二维码](./cnocr-qq.jpg)
# 在线 Demo
地址:[https://share.streamlit.io/breezedeus/cnstd/st-deploy/cnstd/app.py](https://share.streamlit.io/breezedeus/cnstd/st-deploy/cnstd/app.py)
![Demo](figs/demo.jpg)
\ No newline at end of file
{
"vocab_fp": "label_cn.txt",
"vocab_fp": "cnocr/label_cn.txt",
"img_folder": "data/images",
"gpus": 0,
"epochs": 2,
"batch_size": 64,
"epochs": 20,
"batch_size": 4,
"num_workers": 0,
"pin_memory": false,
"optimizer": "adam",
"learning_rate": 1e-3,
"weight_decay": 0,
"lr_scheduler": {
"name": "multi_step",
"milestones": [5, 10, 16, 22, 30],
"gamma": 0.5
"name": "cos_warmup",
"min_lr_mult_factor": 0.01,
"warmup_epochs": 0.2
},
"precision": 32,
"limit_train_batches": 1.0,
......
{
"vocab_fp": "label_cn.txt",
"img_folder": "data/images",
"vocab_fp": "cnocr/label_cn.txt",
"img_folder": "data/output_normal",
"gpus": [0],
"epochs": 40,
"batch_size": 200,
"batch_size": 100,
"num_workers": 12,
"pin_memory": true,
"optimizer": "adam",
"learning_rate": 3e-3,
"weight_decay": 0,
"lr_scheduler": {
"name": "multi_step",
"name": "cos_warmup",
"min_lr_mult_factor": 0.01,
"warmup_epochs": 0.2,
"milestones": [5, 10, 16, 22, 30],
"gamma": 0.5
},
......
# 常见问题(FAQ)
## CnOcr 是免费的吗?
CnOcr是免费的,而且是开源的。可以按需自行调整发布或商业使用。
## CnOcr 能识别英文以及空格吗?
可以。
## CnOcr 能识别繁体中文吗?
不能。
## CnOcr 能识别竖排文字的图片吗?
不能。
# CnOcr
**[CnOcr](https://github.com/breezedeus/cnocr)****Python 3** 下的**文字识别****Optical Character Recognition**,简称**OCR**)工具包,
支持**中文****英文**的常见字符识别,自带了多个[训练好的识别模型](models.md),安装后即可直接使用。
欢迎扫码加入[QQ交流群](contact.md)
CnOcr的目标是**使用简单**
可以使用 [在线 Demo](demo.md) 查看效果。
## 安装简单
嗯,安装真的很简单。
```bash
pip install cnocr
```
更多说明可见 [安装文档](install.md)
## 使用简单
使用 `CnOcr.ocr()` 识别下图:
![多行文字图片](examples/multi-line_cn1.png)
**调用示例**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr('examples/multi-line_cn1.png')
print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'examples/multi-line_cn1.png'
img = read_img(img_fp)
res = ocr.ocr(img)
print("Predicted Chars:", res)
```
更多说明可见 [使用方法](usage.md)
## 命令行工具
具体见 [命令行工具](command.md)
### 训练自己的模型
具体见 [模型训练](train.md)
## 效果示例
| 图片 | OCR结果 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| ![examples/helloworld.jpg](./examples/helloworld.jpg) | Hello world!你好世界 |
| ![examples/chn-00199989.jpg](./examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 |
| ![examples/chn-00199980.jpg](./examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 |
| ![examples/chn-00199984.jpg](./examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 |
| ![examples/chn-00199985.jpg](./examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 |
| ![examples/chn-00199981.jpg](./examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 |
| ![examples/00199975.jpg](./examples/00199975.jpg) | nd-chips fructed ast |
| ![examples/00199978.jpg](./examples/00199978.jpg) | zouna unpayably Raqu |
| ![examples/00199979.jpg](./examples/00199979.jpg) | ape fissioning Senat |
| ![examples/00199971.jpg](./examples/00199971.jpg) | ling oughtlins near |
| ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为<br />每一个手机号码和邮件地址背后<br />都会对应着一个账户--这个账<br />户可以是信用卡账户、借记卡账<br />户,也包括邮局汇款、手机代<br />收、电话代收、预付费卡和点卡<br />等多种形式。 |
| ![examples/multi-line_cn2.png](./examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,<br />意味着传播方式的变化。过去主流<br />的是大众传播,现在互动性和定制<br />性带来了新的挑战——如何让品牌<br />与消费者更加互动。 |
| ![examples/multi-line_en_white.png](./examples/multi-line_en_white.png) | This chapter is currently only available <br />in this web version. ebook and print will follow.<br />Convolutional neural networks learn abstract <br />features and concepts from raw image pixels. Feature<br />Visualization visualizes the learned features <br />by activation maximization. Network Dissection labels<br />neural network units (e.g. channels) with human concepts. |
| ![examples/multi-line_en_black.png](./examples/multi-line_en_black.png) | transforms the image many times. First, the image <br />goes through many convolutional layers. In those<br />convolutional layers, the network learns new <br />and increasingly complex features in its layers. Then the <br />transformed image information goes through <br />the fully connected layers and turns into a classification<br />or prediction. |
## 其他文档
* [场景文字识别技术介绍(PPT+视频)](std_ocr.md)
* 对于通用场景的文字识别,使用 [文本检测CnStd + 文字识别CnOcr](cnstd_cnocr.md)
* [RELEASE文档](RELEASE.md)
## 未来工作
* [x] 支持图片包含多行文字 (`Done`)
* [x] crnn模型支持可变长预测,提升灵活性 (since `V1.0.0`)
* [x] 完善测试用例 (`Doing`)
* [x] 修bugs(目前代码还比较凌乱。。) (`Doing`)
* [x] 支持`空格`识别(since `V1.1.0`
* [x] 尝试新模型,如 DenseNet,进一步提升识别准确率(since `V1.1.0`
* [x] 优化训练集,去掉不合理的样本;在此基础上,重新训练各个模型
* [x] 由 MXNet 改为 PyTorch 架构(since `V2.0.0`
* [ ] 基于 PyTorch 训练更高效的模型
* [ ] 支持列格式的文字识别
## 安装
嗯,安装真的很简单。
```bash
pip install cnocr
```
安装速度慢的话,可以指定国内的安装源,如使用豆瓣源:
```bash
pip install cnocr -i https://pypi.doubanio.com/simple
```
> 注意:请使用 **Python3**(3.6以及之后版本应该都行),没测过Python2下是否ok。
## 可直接使用的模型
cnocr的ocr模型可以分为两阶段:第一阶段是获得ocr图片的局部编码向量,第二部分是对局部编码向量进行序列学习,获得序列编码向量。目前的PyTorch版本的两个阶段分别包含以下模型:
1. 局部编码模型(emb model)
* **`densenet_lite_<numbers>`**:一个微型的`densenet`网络;其中的`<number>`表示模型中每个block包含的层数。
* **`densenet`**:一个小型的`densenet`网络;
2. 序列编码模型(seq model)
* **`fc`**:两层的全连接网络;
* **`gru`**:一层的GRU网络;
* **`lstm`**:一层的LSTM网络。
cnocr **V2.1** 目前包含以下可直接使用的模型,训练好的模型都放在 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 项目中,可免费下载使用:
| Name | 参数规模 | 模型文件大小 | 准确度 | 平均推断耗时(毫秒/图) |
| --- | --- | --- | --- | --- |
| densenet\_lite\_114-fc | 1.3 M | 4.9 M | 0.9274 | 9.229 |
| densenet\_lite\_124-fc | 1.3 M | 5.1 M | 0.9429 | 10.112 |
| densenet\_lite\_134-fc | 1.4 M | 5.4 M | 0.954 | 10.843 |
| densenet\_lite\_136-fc | 1.5M | 5.9 M | 0.9631 | 11.499 |
| densenet\_lite\_134-gru | 2.9 M | 11 M | 0.9738 | 17.042 |
| densenet\_lite\_136-gru | 3.1 M | 12 M | 0.9756 | 17.725 |
> 模型名称是由局部编码模型和序列编码模型名称拼接而成,以符合"-"分割。
#
# This file is autogenerated by pip-compile
# To update, run:
#
# pip-compile --output-file=requirements.txt requirements.in
#
--index-url https://pypi.doubanio.com/simple
absl-py==0.13.0 # via tensorboard
aiohttp==3.7.4.post0 # via fsspec
async-timeout==3.0.1 # via aiohttp
attrs==21.2.0 # via aiohttp
cachetools==4.2.2 # via google-auth
certifi==2020.4.5.1 # via requests
chardet==3.0.4 # via aiohttp, requests
click==8.0.1 # via -r requirements.in
fsspec[http]==2021.7.0 # via pytorch-lightning
future==0.18.2 # via pytorch-lightning
google-auth-oauthlib==0.4.5 # via tensorboard
google-auth==1.35.0 # via google-auth-oauthlib, tensorboard
grpcio==1.39.0 # via tensorboard
idna==2.9 # via requests, yarl
markdown==3.3.4 # via tensorboard
multidict==5.1.0 # via aiohttp, yarl
numpy==1.18.3 # via -r requirements.in, pytorch-lightning, tensorboard, torchmetrics, torchvision
oauthlib==3.1.1 # via requests-oauthlib
packaging==21.0 # via pytorch-lightning, torchmetrics
pillow==5.3.0 # via -r requirements.in, torchvision
protobuf==3.17.3 # via tensorboard
pyasn1-modules==0.2.8 # via google-auth
pyasn1==0.4.8 # via pyasn1-modules, rsa
pydeprecate==0.3.1 # via pytorch-lightning
pyparsing==2.4.7 # via packaging
pytorch-lightning==1.4.4 # via -r requirements.in
pyyaml==5.4.1 # via pytorch-lightning
requests-oauthlib==1.3.0 # via google-auth-oauthlib
requests==2.23.0 # via fsspec, requests-oauthlib, tensorboard
rsa==4.7.2 # via google-auth
six==1.14.0 # via absl-py, google-auth, grpcio, protobuf
tensorboard-data-server==0.6.1 # via tensorboard
tensorboard-plugin-wit==1.8.0 # via tensorboard
tensorboard==2.6.0 # via pytorch-lightning
torch==1.9.0 # via -r requirements.in, pytorch-lightning, torchmetrics, torchvision
torchmetrics==0.5.0 # via pytorch-lightning
torchvision==0.10.0 # via -r requirements.in
tqdm==4.45.0 # via -r requirements.in, pytorch-lightning
typing-extensions==3.10.0.0 # via aiohttp, pytorch-lightning, torch
urllib3==1.25.9 # via requests
werkzeug==2.0.1 # via tensorboard
wheel==0.37.0 # via tensorboard
yarl==1.6.3 # via aiohttp
# The following packages are considered to be unsafe in a requirements file:
# setuptools
# for mkdocs
mkdocs
mkdocs-macros-plugin
mkdocs-material
mkdocstrings
/* Sidebar */
.md-sidebar {
width: 10rem;
}
/* Indenting docstrings */
div.doc-contents:not(.first) {
padding-left: 25px;
border-left: 4px solid #e6e6e6;
margin-bottom: 80px;
}
/* Functions inside classes */
.md-typeset h5 {
font-size: 0.8rem;
text-transform: none !important;
}
/* Code highlights (custom) */
.highlight .hll { background-color: #ffffcc }
.highlight .c { color: #408080; font-style: italic } /* Comment */
.highlight .err { border: 1px solid #FF0000 } /* Error */
.highlight .k { color: #008000; font-weight: bold } /* Keyword */
.highlight .o { color: #AE2FFE } /* Operator */
.highlight .cm { color: #408080; font-style: italic } /* Comment.Multiline */
.highlight .cp { color: #BC7A00 } /* Comment.Preproc */
.highlight .c1 { color: #408080; font-style: italic } /* Comment.Single */
.highlight .cs { color: #408080; font-style: italic } /* Comment.Special */
.highlight .gd { color: #A00000 } /* Generic.Deleted */
.highlight .ge { font-style: italic } /* Generic.Emph */
.highlight .gr { color: #FF0000 } /* Generic.Error */
.highlight .gh { color: #000080; font-weight: bold } /* Generic.Heading */
.highlight .gi { color: #00A000 } /* Generic.Inserted */
.highlight .go { color: #808080 } /* Generic.Output */
.highlight .gp { color: #000080; font-weight: bold } /* Generic.Prompt */
.highlight .gs { font-weight: bold } /* Generic.Strong */
.highlight .gu { color: #800080; font-weight: bold } /* Generic.Subheading */
.highlight .gt { color: #0040D0 } /* Generic.Traceback */
.highlight .kc { color: #008000; font-weight: bold } /* Keyword.Constant */
.highlight .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */
.highlight .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */
.highlight .kp { color: #008000 } /* Keyword.Pseudo */
.highlight .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */
.highlight .kt { color: #B00040 } /* Keyword.Type */
.highlight .m { color: #008000 } /* Literal.Number */
.highlight .s { color: #BA2121 } /* Literal.String */
.highlight .na { color: #7D9029 } /* Name.Attribute */
.highlight .nb { color: #008000 } /* Name.Builtin */
.highlight .nc { color: #0000FF; font-weight: bold } /* Name.Class */
.highlight .no { color: #880000 } /* Name.Constant */
.highlight .nd { color: #AA22FF } /* Name.Decorator */
.highlight .ni { color: #999999; font-weight: bold } /* Name.Entity */
.highlight .ne { color: #D2413A; font-weight: bold } /* Name.Exception */
.highlight .nf { color: #0000FF } /* Name.Function */
.highlight .nl { color: #A0A000 } /* Name.Label */
.highlight .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */
.highlight .nt { color: #008000; font-weight: bold } /* Name.Tag */
.highlight .nv { color: #19177C } /* Name.Variable */
.highlight .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */
.highlight .w { color: #bbbbbb } /* Text.Whitespace */
.highlight .mf { color: #008000 } /* Literal.Number.Float */
.highlight .mh { color: #008000 } /* Literal.Number.Hex */
.highlight .mi { color: #008000 } /* Literal.Number.Integer */
.highlight .mo { color: #008000 } /* Literal.Number.Oct */
.highlight .sb { color: #BA2121 } /* Literal.String.Backtick */
.highlight .sc { color: #BA2121 } /* Literal.String.Char */
.highlight .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */
.highlight .s2 { color: #BA2121 } /* Literal.String.Double */
.highlight .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */
.highlight .sh { color: #BA2121 } /* Literal.String.Heredoc */
.highlight .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */
.highlight .sx { color: #008000 } /* Literal.String.Other */
.highlight .sr { color: #BB6688 } /* Literal.String.Regex */
.highlight .s1 { color: #BA2121 } /* Literal.String.Single */
.highlight .ss { color: #19177C } /* Literal.String.Symbol */
.highlight .bp { color: #008000 } /* Name.Builtin.Pseudo */
.highlight .vc { color: #19177C } /* Name.Variable.Class */
.highlight .vg { color: #19177C } /* Name.Variable.Global */
.highlight .vi { color: #19177C } /* Name.Variable.Instance */
.highlight .il { color: #008000 } /* Literal.Number.Integer.Long */
# 场景文字识别技术介绍
为了识别一张图片中的文字,通常包含两个步骤:
1. **文本检测**:检测出图片中文字所在的位置;
2. **文字识别**:识别包含文字的图片局部,预测具体的文字。
如下图:
![文字识别流程](figs/std-ocr.jpg)
更多相关介绍可参考作者分享:**文本检测与识别**[PPT](intro-cnstd-cnocr.pdf)[B站视频](https://www.bilibili.com/video/BV1uU4y1N7Ba))。
---
cnocr 主要功能是上面的第二步,也即文字识别。有些应用场景(如下图的文字截图图片等),待检测的图片背景很简单,如白色或其他纯色,
cnocr 内置的文字检测和分行模块可以处理这种简单场景。
![文字截图图片](examples/multi-line_cn1.png)
但如果用于其他复杂的场景文字图片(如下图)的识别,
cnocr 需要结合其他的场景文字检测引擎使用,推荐文字检测引擎 **[CnStd](https://github.com/breezedeus/cnstd)**
![复杂场景文字图片](examples/taobao4.jpg)
具体使用方式,可参考 [文本检测CnStd + 文字识别CnOcr](cnstd_cnocr.md)
# 模型训练
自带模型基于 `500+万` 的文字图片训练而成。
## 训练命令
[命令行工具](command.md) 介绍了训练命令。使用命令 **`cnocr train`** 训练文本检测模型,以下是使用说明:
```bash
(venv) ➜ cnocr git:(dev) ✗ cnocr train -h
Usage: cnocr train [OPTIONS]
Options:
-m, --model-name TEXT 模型名称。默认值为 densenet_lite_136-fc
-i, --index-dir TEXT 索引文件所在的文件夹,会读取文件夹中的 train.tsv 和 dev.tsv 文件
[required]
--train-config-fp TEXT 训练使用的json配置文件,参考
`docs/examples/train_config.json`
[required]
-r, --resume-from-checkpoint TEXT
恢复此前中断的训练状态,继续训练。默认为 `None`
-p, --pretrained-model-fp TEXT 导入的训练好的模型,作为初始模型。优先级低于"--restore-training-
fp",当传入"--restore-training-fp"时,此传入失效。默认为
`None`
-h, --help Show this message and exit.
```
例如可以使用以下命令进行训练:
```bash
cnocr train -m densenet_lite_136-fc --index-dir data/test --train-config-fp docs/examples/train_config.json
```
训练数据的格式见文件夹 [data/test](https://github.com/breezedeus/cnocr/blob/master/data/test) 中的 [train.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/train.tsv)[dev.tsv](https://github.com/breezedeus/cnocr/blob/master/data/test/dev.tsv) 文件。
具体使用也可参考文件 [Makefile](https://github.com/breezedeus/cnocr/blob/master/Makefile)
# 模型精调
如果需要在已有模型的基础上精调模型,需要把训练配置中的学习率设置的较小,`lr_scheduler`的设置可参考以下:
```json
"learning_rate": 3e-5,
"lr_scheduler": {
"name": "cos_warmup",
"min_lr_mult_factor": 0.01,
"warmup_epochs": 2
},
```
> 注:需要尽量避免过度精调!
# 使用方法
## 模型文件自动下载
首次使用cnocr时,系统会**自动下载** zip格式的模型压缩文件,并存于 `~/.cnocr`目录(Windows下默认路径为 `C:\Users\<username>\AppData\Roaming\cnocr`)。
下载后的zip文件代码会自动对其解压,然后把解压后的模型相关目录放于`~/.cnocr/2.1`目录中。
如果系统无法自动成功下载zip文件,则需要手动从 **[cnstd-cnocr-models](https://huggingface.co/breezedeus/cnstd-cnocr-models/tree/main)** 下载此zip文件并把它放于 `~/.cnocr/2.1`目录。如果下载太慢,也可以从 [百度云盘](https://pan.baidu.com/s/1N6HoYearUzU0U8NTL3K35A) 下载, 提取码为 ` gcig`
放置好zip文件后,后面的事代码就会自动执行了。
## 预测代码
### 针对多行文字的图片识别
如果待识别的图片包含多行文字,或者可能包含多行文字(如下图),可以使用 `CnOcr.ocr()` 进行识别。
![多行文字图片](examples/multi-line_cn1.png)
**调用示例**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr('docs/examples/multi-line_cn1.png')
print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'docs/examples/multi-line_cn1.png'
img = read_img(img_fp)
res = ocr.ocr(img)
print("Predicted Chars:", res)
```
### 针对单行文字的图片识别
如果明确知道待识别的图片包含单行文字(如下图),可以使用 `CnOcr.ocr_for_single_line()` 进行识别。
![单行文字图片](examples/helloworld.jpg)
**调用示例**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr_for_single_line('docs/examples/helloworld.jpg')
print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'docs/examples/helloworld.jpg'
img = read_img(img_fp)
res = ocr.ocr_for_single_line(img)
print("Predicted Chars:", res)
```
## 效果示例
| 图片 | OCR结果 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| ![examples/helloworld.jpg](./examples/helloworld.jpg) | Hello world!你好世界 |
| ![examples/chn-00199989.jpg](./examples/chn-00199989.jpg) | 铑泡胭释邑疫反隽寥缔 |
| ![examples/chn-00199980.jpg](./examples/chn-00199980.jpg) | 拇箬遭才柄腾戮胖惬炫 |
| ![examples/chn-00199984.jpg](./examples/chn-00199984.jpg) | 寿猿嗅髓孢刀谎弓供捣 |
| ![examples/chn-00199985.jpg](./examples/chn-00199985.jpg) | 马靼蘑熨距额猬要藕萼 |
| ![examples/chn-00199981.jpg](./examples/chn-00199981.jpg) | 掉江悟厉励.谌查门蠕坑 |
| ![examples/00199975.jpg](./examples/00199975.jpg) | nd-chips fructed ast |
| ![examples/00199978.jpg](./examples/00199978.jpg) | zouna unpayably Raqu |
| ![examples/00199979.jpg](./examples/00199979.jpg) | ape fissioning Senat |
| ![examples/00199971.jpg](./examples/00199971.jpg) | ling oughtlins near |
| ![examples/multi-line_cn1.png](./examples/multi-line_cn1.png) | 网络支付并无本质的区别,因为<br />每一个手机号码和邮件地址背后<br />都会对应着一个账户--这个账<br />户可以是信用卡账户、借记卡账<br />户,也包括邮局汇款、手机代<br />收、电话代收、预付费卡和点卡<br />等多种形式。 |
| ![examples/multi-line_cn2.png](./examples/multi-line_cn2.png) | 当然,在媒介越来越多的情形下,<br />意味着传播方式的变化。过去主流<br />的是大众传播,现在互动性和定制<br />性带来了新的挑战——如何让品牌<br />与消费者更加互动。 |
| ![examples/multi-line_en_white.png](./examples/multi-line_en_white.png) | This chapter is currently only available <br />in this web version. ebook and print will follow.<br />Convolutional neural networks learn abstract <br />features and concepts from raw image pixels. Feature<br />Visualization visualizes the learned features <br />by activation maximization. Network Dissection labels<br />neural network units (e.g. channels) with human concepts. |
| ![examples/multi-line_en_black.png](./examples/multi-line_en_black.png) | transforms the image many times. First, the image <br />goes through many convolutional layers. In those<br />convolutional layers, the network learns new <br />and increasingly complex features in its layers. Then the <br />transformed image information goes through <br />the fully connected layers and turns into a classification<br />or prediction. |
## 详细使用说明
[类CnOcr](cnocr/cn_ocr.md) 是识别主类,包含了三个函数针对不同场景进行文字识别。类`CnOcr`的初始化函数如下:
```python
class CnOcr(object):
def __init__(
self,
model_name: str = 'densenet_lite_136-fc'
*,
cand_alphabet: Optional[Union[Collection, str]] = None,
context: str = 'cpu', # ['cpu', 'gpu', 'cuda']
model_fp: Optional[str] = None,
root: Union[str, Path] = data_dir(),
**kwargs,
):
```
其中的几个参数含义如下:
* `model_name`: 模型名称,即上面表格第一列中的值。默认为 `densenet_lite_136-fc`
* `cand_alphabet`: 待识别字符所在的候选集合。默认为 `None`,表示不限定识别字符范围。取值可以是字符串,如 `"0123456789"`,或者字符列表,如 `["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"]`
* `cand_alphabet`也可以初始化后通过类函数 `CnOcr.set_cand_alphabet(cand_alphabet)` 进行设置。这样同一个实例也可以指定不同的`cand_alphabet`进行识别。
* `context`:预测使用的机器资源,可取值为字符串`cpu``gpu``cuda:0`等。
* `model_fp`: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件(`.ckpt` 文件)。
* `root`: 模型文件所在的根目录。
* Linux/Mac下默认值为 `~/.cnocr`,表示模型文件所处文件夹类似 `~/.cnocr/2.1/densenet_lite_136-fc`
* Windows下默认值为 `C:\Users\<username>\AppData\Roaming\cnocr`
每个参数都有默认取值,所以可以不传入任何参数值进行初始化:`ocr = CnOcr()`
---
`CnOcr`主要包含三个函数,下面分别说明。
### 1. 函数`CnOcr.ocr(img_fp)`
函数`CnOcr.ocr(img_fp)`可以对包含多行文字(或单行)的图片进行文字识别。
**函数说明**
- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor``np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]``channel` 可以等于`1`(灰度图片)或者`3``RGB`格式的彩色图片)。
- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。
**调用示例**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr('examples/multi-line_cn1.png')
print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'examples/multi-line_cn1.png'
img = read_img(img_fp)
res = ocr.ocr(img)
print("Predicted Chars:", res)
```
上面使用的图片文件 [docs/examples/multi-line_cn1.png](./examples/multi-line_cn1.png)内容如下:
![examples/multi-line_cn1.png](./examples/multi-line_cn1.png)
上面预测代码段的返回结果如下:
```bash
Predicted Chars: [
(['网', '络', '支', '付', '并', '无', '本', '质', '的', '区', '别', ',', '因', '为'], 0.8677546381950378),
(['每', '一', '个', '手', '机', '号', '码', '和', '邮', '件', '地', '址', '背', '后'], 0.6706454157829285),
(['都', '会', '对', '应', '着', '一', '个', '账', '户', '一', '一', '这', '个', '账'], 0.5052655935287476),
(['户', '可', '以', '是', '信', '用', '卡', '账', '户', '、', '借', '记', '卡', '账'], 0.7785991430282593),
(['户', ',', '也', '包', '括', '邮', '局', '汇', '款', '、', '手', '机', '代'], 0.37458470463752747),
(['收', '、', '电', '话', '代', '收', '、', '预', '付', '费', '卡', '和', '点', '卡'], 0.7326119542121887),
(['等', '多', '种', '形', '式', '。'], 0.14462216198444366)]
```
### 2. 函数`CnOcr.ocr_for_single_line(img_fp)`
如果明确知道要预测的图片中只包含了单行文字,可以使用函数`CnOcr.ocr_for_single_line(img_fp)`进行识别。和 `CnOcr.ocr()`相比,`CnOcr.ocr_for_single_line()`结果可靠性更强,因为它不需要做额外的分行处理。
**函数说明**
- 输入参数 `img_fp`: 可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor``np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]``channel` 可以等于`1`(灰度图片)或者`3``RGB`格式的彩色图片)。
- 返回值:为一个`tuple`,其中存储了对一行文字的识别结果,也包含了识别概率值。类似这样`(['第', '一', '行'], 0.80)`,其中的数字为对应的识别概率值。
**调用示例**
```python
from cnocr import CnOcr
ocr = CnOcr()
res = ocr.ocr_for_single_line('examples/rand_cn1.png')
print("Predicted Chars:", res)
```
或:
```python
from cnocr.utils import read_img
from cnocr import CnOcr
ocr = CnOcr()
img_fp = 'examples/rand_cn1.png'
img = read_img(img_fp)
res = ocr.ocr_for_single_line(img)
print("Predicted Chars:", res)
```
对图片文件 [docs/examples/rand_cn1.png](./examples/rand_cn1.png)
![examples/rand_cn1.png](./examples/rand_cn1.png)
的预测结果如下:
```bash
Predicted Chars: (['笠', '淡', '嘿', '骅', '谧', '鼎', '皋', '姚', '歼', '蠢', '驼', '耳', '胬', '挝', '涯', '狗', '蒽', '了', '狞'], 0.7832438349723816)
```
### 3. 函数`CnOcr.ocr_for_single_lines(img_list, batch_size=1)`
函数`CnOcr.ocr_for_single_lines(img_list)`可以**对多个单行文字图片进行批量预测**。函数`CnOcr.ocr(img_fp)``CnOcr.ocr_for_single_line(img_fp)`内部其实都是调用的函数`CnOcr.ocr_for_single_lines(img_list)`
**函数说明**
- 输入参数` img_list`: 为一个`list`;其中每个元素可以是需要识别的图片文件路径(如下例);或者是已经从图片文件中读入的数组,类型可以为 `torch.Tensor``np.ndarray`,取值应该是`[0,255]`的整数,维数应该是 `[height, width]` (灰度图片)或者 `[height, width, channel]``channel` 可以等于`1`(灰度图片)或者`3``RGB`格式的彩色图片)。
- 输入参数 `batch_size`: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`
- 返回值:为一个嵌套的`list`,其中的每个元素存储了对一行文字的识别结果,其中也包含了识别概率值。类似这样`[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]`,其中的数字为对应的识别概率值。
**调用示例**
```python
import numpy as np
from cnocr.utils import read_img
from cnocr import CnOcr, line_split
ocr = CnOcr()
img_fp = 'examples/multi-line_cn1.png'
img = read_img(img_fp)
line_imgs = line_split(np.squeeze(img, -1), blank=True)
line_img_list = [line_img for line_img, _ in line_imgs]
res = ocr.ocr_for_single_lines(line_img_list)
print("Predicted Chars:", res)
```
更详细的使用方法,可参考 [tests/test_cnocr.py](https://github.com/breezedeus/cnocr/blob/master/tests/test_cnocr.py) 中提供的测试用例。
# 可取值:['densenet-s']
ENCODER_NAME = densenet-s
ENCODER_NAME = densenet-lite-136
# 可取值:['fc', 'gru', 'lstm']
DECODER_NAME = gru
DECODER_NAME = fclite
MODEL_NAME = $(ENCODER_NAME)-$(DECODER_NAME)
EPOCH = 41
INDEX_DIR = data
TRAIN_CONFIG_FP = examples/train_config_gpu.json
INDEX_DIR = data/output_normal
TRAIN_CONFIG_FP = docs/examples/train_config_gpu.json
train:
cnocr train -m $(MODEL_NAME) --index-dir $(INDEX_DIR) --train-config-fp $(TRAIN_CONFIG_FP)
evaluate:
python scripts/cnocr_evaluate.py --model-name $(MODEL_NAME) --model-epoch $(EPOCH) -i $(REC_DATA_ROOT_DIR)/test-part.txt --image-prefix-dir $(REC_DATA_ROOT_DIR) --batch-size 128 --gpu 1 -o evaluate/$(MODEL_NAME)-$(EPOCH)
cnocr evaluate -m $(MODEL_NAME) -i $(REC_DATA_ROOT_DIR)/test-part.txt --image-folder $(REC_DATA_ROOT_DIR) --batch-size 128 -c cuda:0 -o eval_results/$(MODEL_NAME)-$(EPOCH)
filter:
python scripts/filter_samples.py --sample_file $(REC_DATA_ROOT_DIR)/test-part.txt --badcases_file evaluate/$(MODEL_NAME)-$(EPOCH)/badcases.txt --distance_thrsh 2 -o $(REC_DATA_ROOT_DIR)/new.txt
predict:
cnocr predict -m $(MODEL_NAME) -f examples/rand_cn1.png
cnocr predict -m $(MODEL_NAME) -f docs/examples/rand_cn1.png
......
# Project information
site_name: CnOcr
site_url: https://cnocr.readthedocs.io
site_description: CnOcr 使用说明
site_author: Breezedeus
# Repository
repo_url: https://github.com/breezedeus/cnocr
repo_name: Breezedeus/CnOcr
edit_uri: "" #disables edit button
# Copyright
copyright: Copyright &copy; 2021
# Social media
extra:
social:
- icon: fontawesome/brands/github
link: https://github.com/breezedeus
- icon: fontawesome/brands/zhihu
link: https://www.zhihu.com/people/breezedeus-50
- icon: fontawesome/brands/youtube
link: https://space.bilibili.com/509307267
- icon: fontawesome/brands/twitter
link: https://twitter.com/breezedeus
# Configuration
theme:
name: material
# name: readthedocs
logo: figs/jinlong.png
favicon: figs/jinlong.ico
palette:
primary: indigo
accent: indigo
font:
text: Roboto
code: Roboto Mono
features:
- navigation.tabs
- navigation.expand
icon:
repo: fontawesome/brands/github
# Extensions
markdown_extensions:
- meta
- pymdownx.emoji:
emoji_index: !!python/name:materialx.emoji.twemoji
emoji_generator: !!python/name:materialx.emoji.to_svg
- admonition # alerts
- pymdownx.details # collapsible alerts
- pymdownx.superfences # nest code and content inside alerts
- attr_list # add HTML and CSS to Markdown elements
- pymdownx.inlinehilite # inline code highlights
- pymdownx.keys # show keystroke symbols
- pymdownx.snippets # insert content from other files
- pymdownx.tabbed # content tabs
- footnotes
- def_list
- pymdownx.arithmatex: # mathjax
generic: true
- pymdownx.tasklist:
custom_checkbox: true
clickable_checkbox: false
- codehilite
- pymdownx.highlight:
use_pygments: true
- toc:
toc_depth: 4
# Plugins
plugins:
- search
- macros
- mkdocstrings:
default_handler: python
handlers:
python:
rendering:
show_root_heading: false
show_source: true
show_category_heading: true
watch:
- cnocr
# Extra CSS
extra_css:
- static/css/custom.css
# Extra JS
extra_javascript:
- https://cdnjs.cloudflare.com/ajax/libs/tablesort/5.2.1/tablesort.min.js
- https://polyfill.io/v3/polyfill.min.js?features=es6
- https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js
# Page tree
nav:
- Home: index.md
- Docs:
- 场景文字识别技术介绍: std_ocr.md
- 安装: install.md
- 使用方法: usage.md
- 命令行工具: command.md
- 在线 Demo: demo.md
- 自带模型: models.md
- 文本检测CnStd + 文字识别CnOcr: cnstd_cnocr.md
- 模型训练: train.md
- QQ交流群: contact.md
- 常见问题(FAQ): faq.md
- RELEASE 文档: RELEASE.md
- Python API:
- CnOcr 类: cnocr/cn_ocr.md
\ No newline at end of file
click
tqdm
torch>=1.7.0
torchvision
torch>=1.8.0
torchvision>=0.9.0
numpy
pytorch-lightning
pillow>=5.3.0
python-Levenshtein
......@@ -31,6 +31,7 @@ pyasn1-modules==0.2.8 # via google-auth
pyasn1==0.4.8 # via pyasn1-modules, rsa
pydeprecate==0.3.1 # via pytorch-lightning
pyparsing==2.4.7 # via packaging
python-levenshtein==0.12.0 # via -r requirements.in
pytorch-lightning==1.4.4 # via -r requirements.in
pyyaml==5.4.1 # via pytorch-lightning
requests-oauthlib==1.3.0 # via google-auth-oauthlib
......
......@@ -39,11 +39,12 @@ exec(
required = [
"click",
"tqdm",
"torch>=1.7.0",
"torchvision",
"torch>=1.8.0",
"torchvision>=0.9.0",
'numpy',
"pytorch-lightning",
"pillow>=5.3.0",
"python-Levenshtein",
]
extras_require = {
"dev": ["pip-tools", "pytest", "python-Levenshtein"],
......
......@@ -33,7 +33,7 @@ from cnocr.consts import NUMBERS, AVAILABLE_MODELS
from cnocr.line_split import line_split
root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
example_dir = os.path.join(root_dir, 'examples')
example_dir = os.path.join(root_dir, 'docs/examples')
CNOCR = CnOcr(model_name='densenet-s-fc', model_epoch=None)
SINGLE_LINE_CASES = [
......
......@@ -9,7 +9,7 @@ from torchvision import transforms
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
EXAMPLE_DIR = Path(__file__).parent.parent / 'examples'
EXAMPLE_DIR = Path(__file__).parent.parent / 'docs/examples'
INDEX_DIR = Path(__file__).parent.parent / 'data/test'
from cnocr.utils import save_img
......
......@@ -3,124 +3,171 @@ import os
import sys
from copy import deepcopy
import pytest
import mxnet as mx
from mxnet import nd
import torch
from torch import nn
from torchvision.models import (
resnet50,
resnet34,
resnet18,
mobilenet_v3_large,
mobilenet_v3_small,
shufflenet_v2_x1_0,
shufflenet_v2_x1_5,
shufflenet_v2_x2_0,
)
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
from cnocr.consts import EMB_MODEL_TYPES, SEQ_MODEL_TYPES
from cnocr.utils import set_logger
from cnocr.hyperparams.cn_hyperparams import CnHyperparams
from cnocr.symbols.densenet import _make_dense_layer, DenseNet, cal_num_params
from cnocr.symbols.crnn import (
CRnn,
pipline,
gen_network,
get_infer_shape,
crnn_lstm,
crnn_lstm_lite,
)
from cnocr.utils import set_logger, get_model_size
from cnocr.consts import IMG_STANDARD_HEIGHT, ENCODER_CONFIGS, DECODER_CONFIGS
from cnocr.models.densenet import DenseNet, DenseNetLite
from cnocr.models.mobilenet import gen_mobilenet_v3
logger = set_logger('info')
HP = CnHyperparams()
logger = set_logger('info')
def test_dense_layer():
x = nd.random.randn(128, 64, 32, 280)
net = _make_dense_layer(64, 2, 0.1)
net.initialize()
y = net(x)
logger.info(net)
logger.info(y.shape)
def test_conv():
conv = nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2, bias=False)
input = torch.rand(1, 32, 10, 4)
res = conv(input)
logger.info(res.shape)
def test_densenet():
width = 280
x = nd.random.randn(128, 64, 32, width)
layer_channels = (64, 128, 256, 512)
for shorter in (False, True):
net = DenseNet(layer_channels, shorter=shorter)
net.initialize()
y = net(x)
img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width)
net = DenseNet(32, [2, 2, 2, 2], 64)
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 406464
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
assert tuple(res.shape) == (4, 128, 4, 35)
net = DenseNet(32, [1, 1, 1, 4], 64)
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 301440
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
# assert tuple(res.shape) == (4, 100, 4, 35)
#
# net = DenseNet(32, [1, 1, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 243616
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 116, 4, 35)
#
# net = DenseNet(32, [1, 2, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 230680
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 124, 4, 35)
#
# net = DenseNet(32, [1, 1, 2, 4], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 230680
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 180, 4, 35)
def test_densenet_lite():
width = 280
img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width)
# net = DenseNetLite(32, [2, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 302976
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 128, 2, 35)
# net = DenseNetLite(32, [2, 1, 1], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') # 197952
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 80, 2, 35)
net = DenseNetLite(32, [1, 3, 4], 64)
net.eval()
logger.info(net)
logger.info(y.shape) # (128, 512, 1, 70) or (128, 512, 1, 35)
assert y.shape[2] == 1
expected_seq_len = width // 8 if shorter else width // 4
assert y.shape[3] == expected_seq_len
logger.info('number of parameters: %d', cal_num_params(net)) # 1748224
def test_crnn():
_hp = deepcopy(HP)
_hp.set_seq_length(_hp.img_width // 4)
x = nd.random.randn(128, 64, 32, 280)
layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
for layer_channels in layer_channels_list:
densenet = DenseNet(layer_channels)
crnn = CRnn(_hp, densenet)
crnn.initialize()
y = crnn(x)
logger.info(
'output shape: %s', y.shape
) # res: `(sequence_length, batch_size, 2*num_hidden)`
assert y.shape == (_hp.seq_length, _hp.batch_size, 2 * _hp.num_hidden)
logger.info('number of parameters: %d', cal_num_params(crnn))
def test_crnn_lstm():
hp = deepcopy(HP)
hp.set_seq_length(hp.img_width // 8)
data = mx.sym.Variable('data', shape=(128, 1, 32, 280))
pred = crnn_lstm(HP, data)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
assert pred_shape == (hp.seq_length, hp.batch_size, 2 * hp.num_hidden)
def test_crnn_lstm_lite():
hp = deepcopy(HP)
width = hp.img_width # 280
data = mx.sym.Variable('data', shape=(128, 1, 32, width))
for shorter in (False, True):
pred = crnn_lstm_lite(HP, data, shorter=shorter)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
seq_len = hp.img_width // 8 if shorter else hp.img_width // 4 - 1
assert pred_shape == (seq_len, hp.batch_size, 2 * hp.num_hidden)
def test_pipline():
hp = deepcopy(HP)
hp.set_seq_length(hp.img_width // 4)
hp._loss_type = None # infer mode
layer_channels_list = [(64, 128, 256, 512), (32, 64, 128, 256)]
for layer_channels in layer_channels_list:
densenet = DenseNet(layer_channels)
crnn = CRnn(hp, densenet)
data = mx.sym.Variable('data', shape=(128, 1, 32, 280))
pred = pipline(crnn, hp, data)
pred_shape = pred.infer_shape()[1][0]
logger.info('shape of pred: %s', pred_shape)
assert pred_shape == (hp.batch_size * hp.seq_length, hp.num_classes)
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
assert tuple(res.shape) == (4, 200, 2, 35)
net = DenseNetLite(32, [1, 3, 6], 64)
net.eval()
logger.info(net)
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(img.shape)
res = net(img)
logger.info(res.shape)
assert tuple(res.shape) == (4, 264, 2, 35)
# net = DenseNetLite(32, [1, 2, 2], 64)
# net.eval()
# logger.info(net)
# logger.info(f'model size: {get_model_size(net)}') #
# logger.info(img.shape)
# res = net(img)
# logger.info(res.shape)
# assert tuple(res.shape) == (4, 120, 2, 35)
def test_mobilenet():
width = 280
img = torch.rand(4, 1, IMG_STANDARD_HEIGHT, width)
net = gen_mobilenet_v3('tiny')
net.eval()
logger.info(net)
res = net(img)
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(res.shape)
assert tuple(res.shape) == (4, 192, 2, 35)
net = gen_mobilenet_v3('small')
net.eval()
logger.info(net)
res = net(img)
logger.info(f'model size: {get_model_size(net)}') # 186672
logger.info(res.shape)
assert tuple(res.shape) == (4, 192, 2, 35)
MODEL_NAMES = []
for emb_model in EMB_MODEL_TYPES:
for seq_model in SEQ_MODEL_TYPES:
for emb_model in ENCODER_CONFIGS:
for seq_model in DECODER_CONFIGS:
MODEL_NAMES.append('%s-%s' % (emb_model, seq_model))
@pytest.mark.parametrize(
'model_name', MODEL_NAMES
)
def test_gen_networks(model_name):
logger.info('model_name: %s', model_name)
network, hp = gen_network(model_name, HP)
shape_dict = get_infer_shape(network, HP)
logger.info('shape_dict: %s', shape_dict)
assert shape_dict['pred_fc_output'] == (
hp.batch_size * hp.seq_length,
hp.num_classes,
)
# @pytest.mark.parametrize(
# 'model_name', MODEL_NAMES
# )
# def test_gen_networks(model_name):
# logger.info('model_name: %s', model_name)
# network, hp = gen_network(model_name, HP)
# shape_dict = get_infer_shape(network, HP)
# logger.info('shape_dict: %s', shape_dict)
# assert shape_dict['pred_fc_output'] == (
# hp.batch_size * hp.seq_length,
# hp.num_classes,
# )
......@@ -36,65 +36,3 @@ def test_crnn():
crnn = OcrModel(net, vocab=ENG_LETTERS, lstm_features=512, rnn_units=128)
res2 = crnn(img)
print(res2)
def test_crnn_for_variable_length():
vocab, letter2id = read_charset(VOCAB_FP)
net = DenseNet(32, [2, 2, 2, 2], 64)
crnn = OcrModel(net, vocab=vocab, lstm_features=512, rnn_units=128)
crnn.eval()
model_fp = VOCAB_FP.parent / 'models/last.ckpt'
if model_fp.exists():
print(f'load model params from {model_fp}')
load_model_params(crnn, model_fp)
width = 280
img1 = torch.rand(1, IMG_STANDARD_HEIGHT, width)
img2 = torch.rand(1, IMG_STANDARD_HEIGHT, width // 2)
img3 = torch.rand(1, IMG_STANDARD_HEIGHT, width * 2)
imgs = pad_img_seq([img1, img2, img3])
input_lengths = torch.Tensor([width, width // 2, width * 2])
out = crnn(
imgs, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out['preds'])
padded = torch.zeros((3, 1, IMG_STANDARD_HEIGHT, 50))
imgs2 = torch.cat((imgs, padded), dim=-1)
out2 = crnn(
imgs2, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out2['preds'])
# breakpoint()
def test_crnn_for_variable_length2():
vocab, letter2id = read_charset(VOCAB_FP)
net = DenseNet(32, [2, 2, 2, 2], 64)
crnn = OcrModel(net, vocab=vocab, lstm_features=512, rnn_units=128)
crnn.eval()
model_fp = VOCAB_FP.parent / 'models/last.ckpt'
if model_fp.exists():
print(f'load model params from {model_fp}')
load_model_params(crnn, model_fp)
img_fps = ('helloworld.jpg', 'helloworld-ch.jpg')
imgs = []
input_lengths = []
for fp in img_fps:
img = read_img(VOCAB_FP.parent / 'examples' / fp)
img = rescale_img(img)
input_lengths.append(img.shape[2])
imgs.append(normalize_img_array(img))
imgs = pad_img_seq(imgs)
input_lengths = torch.Tensor(input_lengths)
out = crnn(
imgs, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out['preds'])
padded = torch.zeros((2, 1, IMG_STANDARD_HEIGHT, 80))
imgs2 = torch.cat((imgs, padded), dim=-1)
out2 = crnn(
imgs2, input_lengths=input_lengths, return_model_output=True, return_preds=True,
)
print(out2['preds'])
# breakpoint()
......@@ -9,7 +9,7 @@ from torchvision import transforms
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
EXAMPLE_DIR = Path(__file__).parent.parent / 'examples'
EXAMPLE_DIR = Path(__file__).parent.parent / 'docs/examples'
INDEX_DIR = Path(__file__).parent.parent / 'data/test'
IMAGE_DIR = Path(__file__).parent.parent / 'data/images'
......
......@@ -7,7 +7,7 @@ from mxnet.gluon.utils import download
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(1, os.path.dirname(os.path.abspath(__file__)))
EXAMPLE_DIR = Path(__file__).parent.parent / 'examples'
EXAMPLE_DIR = Path(__file__).parent.parent / 'docs/examples'
from cnocr.utils import check_context, read_img
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册