提交 831c529b 编写于 作者: B breezedeus

optimize `recognize()`

上级 b145462e
......@@ -18,5 +18,6 @@
# under the License.
from .consts import MODEL_VERSION, AVAILABLE_MODELS, NUMBERS, ENG_LETTERS
from .cn_ocr import CnOcr, gen_model
from .cn_ocr import CnOcr
from .recognizer import gen_model
from .line_split import line_split
......@@ -89,7 +89,7 @@ class BaseRecLabelDecode(object):
)
candidates = [word for word in cand_alphabet if word in self.dict]
self._candidates = None if len(candidates) == 0 else candidates
logger.info('candidate chars: %s' % self._candidates)
logger.debug('candidate chars: %s' % self._candidates)
def add_special_char(self, dict_character):
return dict_character
......
......@@ -20,13 +20,14 @@
import os
import logging
from typing import Union, Optional, Collection
from typing import Union, Optional, Collection, List, Tuple
from pathlib import Path
import math
import numpy as np
from PIL import Image
from ..utils import resize_img, data_dir, get_model_file
from ..utils import resize_img, data_dir, get_model_file, read_img
from ..recognizer import Recognizer
from .postprocess import build_post_process
from .utility import (
......@@ -48,16 +49,12 @@ class PPRecognizer(Recognizer):
cand_alphabet: Optional[Union[Collection, str]] = None,
model_fp: Optional[str] = None,
root: Union[str, Path] = data_dir(),
# rec_model_dir,
rec_image_shape="3, 32, 320",
use_space_char=True,
**kwargs
):
self.rec_image_shape = [int(v) for v in rec_image_shape.split(",")]
rec_batch_num = 6
self.rec_batch_num = rec_batch_num
rec_algorithm = ('CRNN',)
self.rec_algorithm = rec_algorithm
self.rec_algorithm = 'CRNN'
self._model_name = model_name
self._model_backend = 'onnx'
......@@ -69,30 +66,6 @@ class PPRecognizer(Recognizer):
'use_space_char': use_space_char,
'cand_alphabet': cand_alphabet,
}
# if self.rec_algorithm == "SRN":
# postprocess_params = {
# 'name': 'SRNLabelDecode',
# "character_dict_path": args.rec_char_dict_path,
# "use_space_char": args.use_space_char
# }
# elif self.rec_algorithm == "RARE":
# postprocess_params = {
# 'name': 'AttnLabelDecode',
# "character_dict_path": args.rec_char_dict_path,
# "use_space_char": args.use_space_char
# }
# elif self.rec_algorithm == 'NRTR':
# postprocess_params = {
# 'name': 'NRTRLabelDecode',
# "character_dict_path": args.rec_char_dict_path,
# "use_space_char": args.use_space_char
# }
# elif self.rec_algorithm == "SAR":
# postprocess_params = {
# 'name': 'SARLabelDecode',
# "character_dict_path": args.rec_char_dict_path,
# "use_space_char": args.use_space_char
# }
self.postprocess_op = build_post_process(postprocess_params)
(
self.predictor,
......@@ -102,28 +75,6 @@ class PPRecognizer(Recognizer):
) = create_predictor(self._model_fp, 'rec', logger)
self.use_onnx = True
# self.benchmark = args.benchmark
# self.use_onnx = args.use_onnx
# if args.benchmark:
# import auto_log
# pid = os.getpid()
# gpu_id = utility.get_infer_gpuid()
# self.autolog = auto_log.AutoLogger(
# model_name="rec",
# model_precision=args.precision,
# batch_size=args.rec_batch_num,
# data_shape="dynamic",
# save_path=None, #args.save_log_path,
# inference_config=self.config,
# pids=pid,
# process_name=None,
# gpu_ids=gpu_id if args.use_gpu else None,
# time_keys=[
# 'preprocess_time', 'inference_time', 'postprocess_time'
# ],
# warmup=0,
# logger=logger)
def _assert_and_prepare_model_files(self, model_fp, root):
if model_fp is not None and not os.path.isfile(model_fp):
raise FileNotFoundError('can not find model file %s' % model_fp)
......@@ -154,15 +105,6 @@ class PPRecognizer(Recognizer):
"""
imgC, imgH, imgW = self.rec_image_shape
# if self.rec_algorithm == 'NRTR':
# img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# # return padding_im
# image_pil = Image.fromarray(np.uint8(img))
# img = image_pil.resize([100, 32], Image.ANTIALIAS)
# img = np.array(img)
# norm_img = np.expand_dims(img, -1)
# norm_img = norm_img.transpose((2, 0, 1))
# return norm_img.astype(np.float32) / 128. - 1.
assert imgC == img.shape[2]
imgW = int((32 * max_wh_ratio))
......@@ -188,114 +130,16 @@ class PPRecognizer(Recognizer):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
# def resize_norm_img_srn(self, img, image_shape):
# imgC, imgH, imgW = image_shape
#
# img_black = np.zeros((imgH, imgW))
# im_hei = img.shape[0]
# im_wid = img.shape[1]
#
# if im_wid <= im_hei * 1:
# img_new = cv2.resize(img, (imgH * 1, imgH))
# elif im_wid <= im_hei * 2:
# img_new = cv2.resize(img, (imgH * 2, imgH))
# elif im_wid <= im_hei * 3:
# img_new = cv2.resize(img, (imgH * 3, imgH))
# else:
# img_new = cv2.resize(img, (imgW, imgH))
#
# img_np = np.asarray(img_new)
# img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
# img_black[:, 0:img_np.shape[1]] = img_np
# img_black = img_black[:, :, np.newaxis]
#
# row, col, c = img_black.shape
# c = 1
#
# return np.reshape(img_black, (c, row, col)).astype(np.float32)
#
# def srn_other_inputs(self, image_shape, num_heads, max_text_length):
#
# imgC, imgH, imgW = image_shape
# feature_dim = int((imgH / 8) * (imgW / 8))
#
# encoder_word_pos = np.array(range(0, feature_dim)).reshape(
# (feature_dim, 1)).astype('int64')
# gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
# (max_text_length, 1)).astype('int64')
#
# gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
# gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
# [-1, 1, max_text_length, max_text_length])
# gsrm_slf_attn_bias1 = np.tile(
# gsrm_slf_attn_bias1,
# [1, num_heads, 1, 1]).astype('float32') * [-1e9]
#
# gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
# [-1, 1, max_text_length, max_text_length])
# gsrm_slf_attn_bias2 = np.tile(
# gsrm_slf_attn_bias2,
# [1, num_heads, 1, 1]).astype('float32') * [-1e9]
#
# encoder_word_pos = encoder_word_pos[np.newaxis, :]
# gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
#
# return [
# encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
# gsrm_slf_attn_bias2
# ]
#
# def process_image_srn(self, img, image_shape, num_heads, max_text_length):
# norm_img = self.resize_norm_img_srn(img, image_shape)
# norm_img = norm_img[np.newaxis, :]
#
# [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
# self.srn_other_inputs(image_shape, num_heads, max_text_length)
#
# gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
# gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
# encoder_word_pos = encoder_word_pos.astype(np.int64)
# gsrm_word_pos = gsrm_word_pos.astype(np.int64)
#
# return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
# gsrm_slf_attn_bias2)
def recognize(
self,
img_list: List[Union[str, Path, np.ndarray]],
batch_size: int = 1,
) -> List[Tuple[str, float]]:
if len(img_list) == 0:
return []
img_list = [self._prepare_img(img) for img in img_list]
# def resize_norm_img_sar(self, img, image_shape,
# width_downsample_ratio=0.25):
# imgC, imgH, imgW_min, imgW_max = image_shape
# h = img.shape[0]
# w = img.shape[1]
# valid_ratio = 1.0
# # make sure new_width is an integral multiple of width_divisor.
# width_divisor = int(1 / width_downsample_ratio)
# # resize
# ratio = w / float(h)
# resize_w = math.ceil(imgH * ratio)
# if resize_w % width_divisor != 0:
# resize_w = round(resize_w / width_divisor) * width_divisor
# if imgW_min is not None:
# resize_w = max(imgW_min, resize_w)
# if imgW_max is not None:
# valid_ratio = min(1.0, 1.0 * resize_w / imgW_max)
# resize_w = min(imgW_max, resize_w)
# resized_image = cv2.resize(img, (resize_w, imgH))
# resized_image = resized_image.astype('float32')
# # norm
# if image_shape[0] == 1:
# resized_image = resized_image / 255
# resized_image = resized_image[np.newaxis, :]
# else:
# resized_image = resized_image.transpose((2, 0, 1)) / 255
# resized_image -= 0.5
# resized_image /= 0.5
# resize_shape = resized_image.shape
# padding_im = -1.0 * np.ones((imgC, imgH, imgW_max), dtype=np.float32)
# padding_im[:, :, 0:resize_w] = resized_image
# pad_shape = padding_im.shape
#
# return padding_im, resize_shape, pad_shape, valid_ratio
#
def recognize(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
width_list = []
......@@ -304,11 +148,8 @@ class PPRecognizer(Recognizer):
# Sorting can speed up the recognition process
indices = np.argsort(np.array(width_list))
rec_res = [['', 0.0]] * img_num
batch_num = self.rec_batch_num
# if self.benchmark:
# self.autolog.times.start()
for beg_img_no in range(0, img_num, batch_num):
end_img_no = min(img_num, beg_img_no + batch_num)
for beg_img_no in range(0, img_num, batch_size):
end_img_no = min(img_num, beg_img_no + batch_size)
norm_img_batch = []
max_wh_ratio = 0
for ino in range(beg_img_no, end_img_no):
......@@ -322,149 +163,49 @@ class PPRecognizer(Recognizer):
)
norm_img = norm_img[np.newaxis, :]
norm_img_batch.append(norm_img)
# elif self.rec_algorithm == "SAR":
# norm_img, _, _, valid_ratio = self.resize_norm_img_sar(
# img_list[indices[ino]], self.rec_image_shape)
# norm_img = norm_img[np.newaxis, :]
# valid_ratio = np.expand_dims(valid_ratio, axis=0)
# valid_ratios = []
# valid_ratios.append(valid_ratio)
# norm_img_batch.append(norm_img)
# else:
# norm_img = self.process_image_srn(
# img_list[indices[ino]], self.rec_image_shape, 8, 25)
# encoder_word_pos_list = []
# gsrm_word_pos_list = []
# gsrm_slf_attn_bias1_list = []
# gsrm_slf_attn_bias2_list = []
# encoder_word_pos_list.append(norm_img[1])
# gsrm_word_pos_list.append(norm_img[2])
# gsrm_slf_attn_bias1_list.append(norm_img[3])
# gsrm_slf_attn_bias2_list.append(norm_img[4])
# norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
# if self.benchmark:
# self.autolog.times.stamp()
if self.rec_algorithm == "SRN":
encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
gsrm_slf_attn_bias1_list = np.concatenate(gsrm_slf_attn_bias1_list)
gsrm_slf_attn_bias2_list = np.concatenate(gsrm_slf_attn_bias2_list)
input_dict = dict()
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
inputs = [
norm_img_batch,
encoder_word_pos_list,
gsrm_word_pos_list,
gsrm_slf_attn_bias1_list,
gsrm_slf_attn_bias2_list,
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = {"predict": outputs[2]}
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
# if self.benchmark:
# self.autolog.times.stamp()
preds = {"predict": outputs[2]}
elif self.rec_algorithm == "SAR":
valid_ratios = np.concatenate(valid_ratios)
inputs = [
norm_img_batch,
valid_ratios,
]
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
else:
input_names = self.predictor.get_input_names()
for i in range(len(input_names)):
input_tensor = self.predictor.get_input_handle(input_names[i])
input_tensor.copy_from_cpu(inputs[i])
self.predictor.run()
outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
outputs.append(output)
# if self.benchmark:
# self.autolog.times.stamp()
preds = outputs[0]
else:
if self.use_onnx:
input_dict = {}
input_dict[self.input_tensor.name] = norm_img_batch
outputs = self.predictor.run(self.output_tensors, input_dict)
preds = outputs[0]
# else:
# self.input_tensor.copy_from_cpu(norm_img_batch)
# self.predictor.run()
# outputs = []
# for output_tensor in self.output_tensors:
# output = output_tensor.copy_to_cpu()
# outputs.append(output)
# # if self.benchmark:
# # self.autolog.times.stamp()
# if len(outputs) != 1:
# preds = outputs
# else:
# preds = outputs[0]
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
# if self.benchmark:
# self.autolog.times.end(stamp=True)
return rec_res
def _prepare_img(
self, img_fp: Union[str, Path, np.ndarray]
) -> np.ndarray:
"""
# def main(args):
# image_file_list = get_image_file_list(args.image_dir)
# text_recognizer = TextRecognizer(args)
# valid_image_file_list = []
# img_list = []
#
# # warmup 2 times
# if args.warmup:
# img = np.random.uniform(0, 255, [32, 320, 3]).astype(np.uint8)
# for i in range(2):
# res = text_recognizer([img] * int(args.rec_batch_num))
#
# for image_file in image_file_list:
# img, flag = check_and_read_gif(image_file)
# if not flag:
# # img = cv2.imread(image_file)
# img = read_img(image_file, gray=False)
# if img is None:
# logger.info("error in loading image:{}".format(image_file))
# continue
# valid_image_file_list.append(image_file)
# img_list.append(img)
# try:
# rec_res, _ = text_recognizer(img_list)
#
# except Exception as E:
# logger.info(traceback.format_exc())
# logger.info(E)
# exit()
# for ino in range(len(img_list)):
# logger.info(
# "Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ino])
# )
# # if args.benchmark:
# # text_recognizer.autolog.report()
#
#
# if __name__ == "__main__":
# main(parse_args())
Args:
img_fp (Union[str, Path, np.ndarray]):
image array with type torch.Tensor or np.ndarray,
with shape [height, width] or [height, width, channel].
channel should be 1 (gray image) or 3 (color image).
Returns:
np.ndarray: with shape (height, width, 3), scale [0, 255]
"""
img = img_fp
if isinstance(img_fp, (str, Path)):
if not os.path.isfile(img_fp):
raise FileNotFoundError(img_fp)
img = read_img(img_fp, gray=False)
if len(img.shape) == 3 and img.shape[2] == 1:
# (H, W, 1) -> (H, W)
img = img.squeeze(-1)
if len(img.shape) == 2:
# (H, W) -> (H, W, 3)
img = np.array(Image.fromarray(img).convert('RGB'))
elif img.shape[2] != 3:
raise ValueError(
'only images with shape [height, width, 1] (gray images), '
'or [height, width, 3] (RGB-formated color images) are supported'
)
return img
......@@ -27,9 +27,9 @@ import numpy as np
from PIL import Image
import torch
from cnocr.consts import MODEL_VERSION, AVAILABLE_MODELS, VOCAB_FP
from cnocr.models.ocr_model import OcrModel
from cnocr.utils import (
from .consts import MODEL_VERSION, AVAILABLE_MODELS, VOCAB_FP
from .models.ocr_model import OcrModel
from .utils import (
data_dir,
get_model_file,
read_charset,
......@@ -42,7 +42,6 @@ from cnocr.utils import (
to_numpy,
)
from .data_utils.aug import NormalizeAug
from .line_split import line_split
from .models.ctc import CTCPostProcessor
logger = logging.getLogger(__name__)
......@@ -218,7 +217,7 @@ class Recognizer(object):
)
candidates = [word for word in cand_alphabet if word in self._letter2id]
self._candidates = None if len(candidates) == 0 else candidates
logger.info('candidate chars: %s' % self._candidates)
logger.debug('candidate chars: %s' % self._candidates)
# def ocr(
# self, img_fp: Union[str, Path, torch.Tensor, np.ndarray]
......@@ -308,7 +307,7 @@ class Recognizer(object):
self,
img_list: List[Union[str, Path, torch.Tensor, np.ndarray]],
batch_size: int = 1,
) -> List[Tuple[List[str], float]]:
) -> List[Tuple[str, float]]:
"""
Batch recognize characters from a list of one-line-characters images.
......@@ -323,8 +322,8 @@ class Recognizer(object):
batch_size: 待处理图片很多时,需要分批处理,每批图片的数量由此参数指定。默认为 `1`。
Returns:
list: list of (list of chars, prob), such as
[(['第', '一', '行'], 0.80), (['第', '二', '行'], 0.75), (['第', '三', '行'], 0.9)]
list: list of (chars, prob), such as
[('第一行', 0.80), ('第二行', 0.75), ('第三行', 0.9)]
"""
if len(img_list) == 0:
return []
......@@ -359,7 +358,7 @@ class Recognizer(object):
for line in out:
chars, prob = line
chars = [c if c != '<space>' else ' ' for c in chars]
res.append((chars, prob))
res.append((''.join(chars), prob))
return res
......
......@@ -70,11 +70,20 @@ def test_ppocr(img_fp, expected):
ocr = CNOCR
img_fp = os.path.join(example_dir, img_fp)
pred = ocr.recognize([img_fp])[0]
print_preds(pred)
assert cal_score([pred], expected) >= 0.8
img = read_img(img_fp, gray=False)
pred = ocr.recognize([img])[0]
print_preds(pred)
assert cal_score([pred], expected) >= 0.8
img = read_img(img_fp, gray=True)
pred = ocr.recognize([img])[0]
print_preds(pred)
assert cal_score([pred], expected) >= 0.8
def test_cand_alphabet():
img_fp = os.path.join(example_dir, 'hybrid.png')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册