提交 b360e4da 编写于 作者: B breezedeus

support for input candidates

上级 b08796f7
......@@ -31,8 +31,11 @@ from cnocr.utils import (
get_model_file,
read_charset,
check_model_name,
check_context, read_img,
load_model_params, rescale_img, pad_img_seq,
check_context,
read_img,
load_model_params,
rescale_img,
pad_img_seq,
)
from .data_utils.aug import NormalizeAug
from .line_split import line_split
......@@ -43,7 +46,9 @@ logger = logging.getLogger(__name__)
def gen_model(model_name, vocab):
check_model_name(model_name)
if not model_name.startswith('densenet-s'):
logger.warning('only "densenet-s" is supported now, use "densenet-s-fc" by default')
logger.warning(
'only "densenet-s" is supported now, use "densenet-s-fc" by default'
)
model_name = 'densenet-s-fc'
model = OcrModel.from_name(model_name, vocab)
return model
......@@ -81,11 +86,9 @@ class CnOcr(object):
root = os.path.join(root, MODEL_VERSION)
self._model_dir = os.path.join(root, self._model_name)
# self._assert_and_prepare_model_files()
self._vocab, self._inv_alph_dict = read_charset(
VOCAB_FP
)
self._vocab, self._letter2id = read_charset(VOCAB_FP)
self._cand_alph_idx = None
self._candidates = None
self.set_cand_alphabet(cand_alphabet)
self.context = context
......@@ -111,9 +114,13 @@ class CnOcr(object):
def _get_module(self, context):
from glob import glob
fps = glob('%s/%s*.ckpt' % (self._model_dir, self._model_file_prefix))
if len(fps) > 1:
raise ValueError('multiple ckpt files are found in %s, not sure which one should be used' % self._model_dir)
raise ValueError(
'multiple ckpt files are found in %s, not sure which one should be used'
% self._model_dir
)
elif len(fps) < 1:
raise FileNotFoundError('no ckpt file is found in %s' % self._model_dir)
......@@ -131,12 +138,24 @@ class CnOcr(object):
:return: None
"""
if cand_alphabet is None:
self._cand_alph_idx = None
self._candidates = None
else:
self._cand_alph_idx = [self._inv_alph_dict[word] for word in cand_alphabet]
self._cand_alph_idx.sort()
cand_alphabet = [word if word != ' ' else '<space>' for word in cand_alphabet]
excluded = set(
[word for word in cand_alphabet if word not in self._letter2id]
)
if excluded:
logger.warning(
'chars in candidates are not in the vocab, ignoring them: %s'
% excluded
)
candidates = [word for word in cand_alphabet if word in self._letter2id]
self._candidates = None if len(candidates) == 0 else candidates
logger.info('candidate chars: %s' % self._candidates)
def ocr(self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]) -> List[Tuple[List[str], float]]:
def ocr(
self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]
) -> List[Tuple[List[str], float]]:
"""
:param img_fp: image file path; or color image mx.nd.NDArray or np.ndarray,
with shape (height, width, 3), and the channels should be RGB formatted.
......@@ -162,7 +181,9 @@ class CnOcr(object):
line_chars_list = self.ocr_for_single_lines(line_img_list)
return line_chars_list
def ocr_for_single_line(self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]) -> Tuple[List[str], float]:
def ocr_for_single_line(
self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]
) -> Tuple[List[str], float]:
"""
Recognize characters from an image with only one-line characters.
:param img_fp: image file path; or image mx.nd.NDArray or np.ndarray,
......@@ -181,7 +202,9 @@ class CnOcr(object):
res = self.ocr_for_single_lines([img])
return res[0]
def ocr_for_single_lines(self, img_list: List[Union[torch.Tensor, np.ndarray]]) -> List[Tuple[List[str], float]]:
def ocr_for_single_lines(
self, img_list: List[Union[torch.Tensor, np.ndarray]]
) -> List[Tuple[List[str], float]]:
"""
Batch recognize characters from a list of one-line-characters images.
:param img_list: list of images, in which each element should be a line image array,
......@@ -206,7 +229,9 @@ class CnOcr(object):
return res
def _preprocess_img_array(self, img: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
def _preprocess_img_array(
self, img: Union[torch.Tensor, np.ndarray]
) -> torch.Tensor:
"""
:param img: image array with type torch.Tensor or np.ndarray,
with shape [height, width] or [channel, height, width].
......@@ -228,5 +253,7 @@ class CnOcr(object):
img_lengths = torch.tensor([img.shape[2] for img in img_list])
imgs = pad_img_seq(img_list)
with torch.no_grad():
out = self._mod(imgs, img_lengths, return_preds=True)
out = self._mod(
imgs, img_lengths, candidates=self._candidates, return_preds=True
)
return out
# coding: utf-8
from typing import Tuple, Dict, Any, Optional, List
from typing import Tuple, Dict, Any, Optional, List, Union
from copy import deepcopy
import numpy as np
......@@ -147,9 +147,20 @@ class OcrModel(nn.Module):
x: torch.Tensor,
input_lengths: torch.Tensor,
target: Optional[List[str]] = None,
return_model_output: bool = False,
candidates: Optional[Union[str, List[str]]] = None,
return_logits: bool = False,
return_preds: bool = False,
) -> Dict[str, Any]:
"""
:param x: [B, 1, H, W]; 一组padding后的图片
:param input_lengths: shape: [B];每张图片padding前的真实长度(宽度)
:param target: 真实的字符串
:param candidates: None or candidate strs; 允许的候选字符集合
:param return_logits: 是否返回预测的logits值
:param return_preds: 是否返回预测的字符串
:return: 预测结果
"""
features = self.encoder(x)
input_lengths = input_lengths // self.encoder.compress_ratio
# B x C x H x W --> B x C*H x W --> B x W x C*H
......@@ -160,9 +171,10 @@ class OcrModel(nn.Module):
logits = self._decode(features_seq, input_lengths)
logits = self.linear(logits)
logits = self._mask_by_candidates(logits, candidates)
out: Dict[str, Any] = {}
if return_model_output:
if return_logits:
out["logits"] = logits
if target is None or return_preds:
......@@ -191,6 +203,25 @@ class OcrModel(nn.Module):
)
return logits
def _mask_by_candidates(
self, logits: torch.Tensor, candidates: Optional[Union[str, List[str]]]
):
if candidates is None:
return logits
_candidates = [self.letter2id[word] for word in candidates]
_candidates.sort()
_candidates = torch.tensor(_candidates, dtype=torch.int64)
candidates = torch.zeros(
(len(self.vocab) + 1,), dtype=torch.bool, device=logits.device
)
candidates[_candidates] = True
candidates[-1] = True # 间隔符号/填充符号,必须为真
candidates = candidates.unsqueeze(0).unsqueeze(0) # 1 x 1 x (vocab_size+1)
logits.masked_fill_(~candidates, -100.0)
return logits
def _compute_loss(
self,
model_output: torch.Tensor,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册