From 28b2d43e57e76cb5a46a866dcceb4c6d66f5378f Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Mon, 7 Dec 2020 19:10:19 +0800 Subject: [PATCH] paddleocr whl adaptation dygraph --- MANIFEST.in | 7 +- doc/doc_ch/whl.md | 58 +++++++- doc/doc_en/whl_en.md | 56 +++++++- paddleocr.py | 248 +++++++++++++++++++++++++--------- setup.py | 2 +- tools/infer/predict_system.py | 33 +++-- 6 files changed, 320 insertions(+), 84 deletions(-) diff --git a/MANIFEST.in b/MANIFEST.in index 388882df..4c16c09d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,7 @@ include LICENSE.txt include README.md -recursive-include ppocr/utils *.txt utility.py character.py check.py -recursive-include ppocr/data/det *.py +recursive-include ppocr/utils *.txt utility.py logging.py +recursive-include ppocr/data/ *.py recursive-include ppocr/postprocess *.py -recursive-include ppocr/postprocess/lanms *.* -recursive-include tools/infer *.py +recursive-include tools/infer *.py \ No newline at end of file diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md index 1b04a9a8..c51f3277 100644 --- a/doc/doc_ch/whl.md +++ b/doc/doc_ch/whl.md @@ -261,6 +261,61 @@ im_show.save('result.jpg') paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true ``` +### 使用网络图片或者numpy数组作为输入 + +1. 网络图片 + +代码使用 +```python +from paddleocr import PaddleOCR, draw_ocr +# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换 +# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。 +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg' +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# 显示结果 +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` +命令行模式 +```bash +paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true +``` + +2. numpy数组 +仅通过代码使用时支持numpy数组作为输入 +```python +from paddleocr import PaddleOCR, draw_ocr +# Paddleocr目前支持中英文、英文、法语、德语、韩语、日语,可以通过修改lang参数进行切换 +# 参数依次为`ch`, `en`, `french`, `german`, `korean`, `japan`。 +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'PaddleOCR/doc/imgs/11.jpg' +img = cv2.imread(img_path) +# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), 如果你自己训练的模型支持灰度图,可以将这句话的注释取消 +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# 显示结果 +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + ## 参数说明 | 字段 | 说明 | 默认值 | @@ -285,6 +340,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | max_text_length | 识别算法能识别的最大文字长度 | 25 | | rec_char_dict_path | 识别模型字典路径,当rec_model_dir使用方式2传参时需要修改为自己的字典路径 | ./ppocr/utils/ppocr_keys_v1.txt | | use_space_char | 是否识别空格 | TRUE | +| drop_score | 对输出按照分数(来自于识别模型)进行过滤,低于此分数的不返回 | 0.5 | | use_angle_cls | 是否加载分类模型 | FALSE | | cls_model_dir | 分类模型所在文件夹。传参方式有两种,1. None: 自动下载内置模型到 `~/.paddleocr/cls`;2.自己转换好的inference模型路径,模型路径下必须包含model和params文件 | None | | cls_image_shape | 分类算法的输入图片尺寸 | "3, 48, 192" | @@ -295,4 +351,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | lang | 模型语言类型,目前支持 中文(ch)和英文(en) | ch | | det | 前向时使用启动检测 | TRUE | | rec | 前向时是否启动识别 | TRUE | -| cls | 前向时是否启动分类 | FALSE | +| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE | diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md index ffbced34..c25999d4 100644 --- a/doc/doc_en/whl_en.md +++ b/doc/doc_en/whl_en.md @@ -271,6 +271,59 @@ im_show.save('result.jpg') paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_dir} --rec_model_dir {your_rec_model_dir} --rec_char_dict_path {your_rec_char_dict_path} --cls_model_dir {your_cls_model_dir} --use_angle_cls true --cls true ``` +### Use web images or numpy array as input + +1. Web image + +Use by code +```python +from paddleocr import PaddleOCR, draw_ocr +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg' +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# show result +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` +Use by command line +```bash +paddleocr --image_dir http://n.sinaimg.cn/ent/transform/w630h933/20171222/o111-fypvuqf1838418.jpg --use_angle_cls=true +``` + +2. Numpy array +Support numpy array as input only when used by code + +```python +from paddleocr import PaddleOCR, draw_ocr +ocr = PaddleOCR(use_angle_cls=True, lang="ch") # need to run only once to download and load model into memory +img_path = 'PaddleOCR/doc/imgs/11.jpg' +img = cv2.imread(img_path) +# img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY), If your own training model supports grayscale images, you can uncomment this line +result = ocr.ocr(img_path, cls=True) +for line in result: + print(line) + +# show result +from PIL import Image +image = Image.open(img_path).convert('RGB') +boxes = [line[0] for line in result] +txts = [line[1][0] for line in result] +scores = [line[1][1] for line in result] +im_show = draw_ocr(image, boxes, txts, scores, font_path='/path/to/PaddleOCR/doc/simfang.ttf') +im_show = Image.fromarray(im_show) +im_show.save('result.jpg') +``` + + ## Parameter Description | Parameter | Description | Default value | @@ -295,6 +348,7 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | max_text_length | The maximum text length that the recognition algorithm can recognize | 25 | | rec_char_dict_path | the alphabet path which needs to be modified to your own path when `rec_model_Name` use mode 2 | ./ppocr/utils/ppocr_keys_v1.txt | | use_space_char | Whether to recognize spaces | TRUE | +| drop_score | Filter the output by score (from the recognition model), and those below this score will not be returned | 0.5 | | use_angle_cls | Whether to load classification model | FALSE | | cls_model_dir | the classification inference model folder. There are two ways to transfer parameters, 1. None: Automatically download the built-in model to `~/.paddleocr/cls`; 2. The path of the inference model converted by yourself, the model and params files must be included in the model path | None | | cls_image_shape | image shape of classification algorithm | "3,48,192" | @@ -305,4 +359,4 @@ paddleocr --image_dir PaddleOCR/doc/imgs/11.jpg --det_model_dir {your_det_model_ | lang | The support language, now only Chinese(ch)、English(en)、French(french)、German(german)、Korean(korean)、Japanese(japan) are supported | ch | | det | Enable detction when `ppocr.ocr` func exec | TRUE | | rec | Enable recognition when `ppocr.ocr` func exec | TRUE | -| cls | Enable classification when `ppocr.ocr` func exec | FALSE | +| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE | diff --git a/paddleocr.py b/paddleocr.py index d3d73cb1..17306e79 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -26,17 +26,50 @@ import requests from tqdm import tqdm from tools.infer import predict_system -from ppocr.utils.utility import initial_logger +from ppocr.utils.logging import get_logger -logger = initial_logger() +logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list __all__ = ['PaddleOCR'] -model_params = { - 'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar', - 'rec': - 'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar', +model_urls = { + 'det': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/det/ch_ppocr_mobile_v1.1_det_infer.tar', + 'rec': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/rec/ch_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' + }, + 'en': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/en/en_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/ic15_dict.txt' + }, + 'french': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/fr/french_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/french_dict.txt' + }, + 'german': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/ge/german_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/german_dict.txt' + }, + 'korean': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/kr/korean_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/korean_dict.txt' + }, + 'japan': { + 'url': + 'https://paddleocr.bj.bcebos.com/20-09-22/mobile/jp/japan_ppocr_mobile_v1.1_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/japan_dict.txt' + } + }, + 'cls': + 'https://paddleocr.bj.bcebos.com/20-09-22/cls/ch_ppocr_mobile_v1.1_cls_infer.tar' } SUPPORT_DET_MODEL = ['DB'] @@ -54,8 +87,8 @@ def download_with_progressbar(url, save_path): progress_bar.update(len(data)) file.write(data) progress_bar.close() - if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: - logger.error("ERROR, something went wrong") + if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes: + logger.error("Something went wrong while downloading models") sys.exit(0) @@ -63,7 +96,7 @@ def maybe_download(model_storage_directory, url): # using custom model if not os.path.exists(os.path.join( model_storage_directory, 'model')) or not os.path.exists( - os.path.join(model_storage_directory, 'params')): + os.path.join(model_storage_directory, 'params')): tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) print('download {} to {}'.format(url, tmp_path)) os.makedirs(model_storage_directory, exist_ok=True) @@ -84,53 +117,102 @@ def maybe_download(model_storage_directory, url): os.remove(tmp_path) -def parse_args(): +def parse_args(mMain=True, add_help=True): import argparse def str2bool(v): return v.lower() in ("true", "t", "1") - parser = argparse.ArgumentParser() - # params for prediction engine - parser.add_argument("--use_gpu", type=str2bool, default=True) - parser.add_argument("--ir_optim", type=str2bool, default=True) - parser.add_argument("--use_tensorrt", type=str2bool, default=False) - parser.add_argument("--gpu_mem", type=int, default=8000) - - # params for text detector - parser.add_argument("--image_dir", type=str) - parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_dir", type=str, default=None) - parser.add_argument("--det_max_side_len", type=float, default=960) - - # DB parmas - parser.add_argument("--det_db_thresh", type=float, default=0.3) - parser.add_argument("--det_db_box_thresh", type=float, default=0.5) - parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) - - # EAST parmas - parser.add_argument("--det_east_score_thresh", type=float, default=0.8) - parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) - parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) - - # params for text recognizer - parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument("--rec_model_dir", type=str, default=None) - parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") - parser.add_argument("--rec_char_type", type=str, default='ch') - parser.add_argument("--rec_batch_num", type=int, default=30) - parser.add_argument("--max_text_length", type=int, default=25) - parser.add_argument( - "--rec_char_dict_path", - type=str, - default="./ppocr/utils/ppocr_keys_v1.txt") - parser.add_argument("--use_space_char", type=bool, default=True) - parser.add_argument("--enable_mkldnn", type=bool, default=False) - - parser.add_argument("--det", type=str2bool, default=True) - parser.add_argument("--rec", type=str2bool, default=True) - parser.add_argument("--use_zero_copy_run", type=bool, default=False) - return parser.parse_args() + if mMain: + parser = argparse.ArgumentParser(add_help=add_help) + # params for prediction engine + parser.add_argument("--use_gpu", type=str2bool, default=True) + parser.add_argument("--ir_optim", type=str2bool, default=True) + parser.add_argument("--use_tensorrt", type=str2bool, default=False) + parser.add_argument("--gpu_mem", type=int, default=8000) + + # params for text detector + parser.add_argument("--image_dir", type=str) + parser.add_argument("--det_algorithm", type=str, default='DB') + parser.add_argument("--det_model_dir", type=str, default=None) + parser.add_argument("--det_limit_side_len", type=float, default=960) + parser.add_argument("--det_limit_type", type=str, default='max') + + # DB parmas + parser.add_argument("--det_db_thresh", type=float, default=0.3) + parser.add_argument("--det_db_box_thresh", type=float, default=0.5) + parser.add_argument("--det_db_unclip_ratio", type=float, default=2.0) + + # EAST parmas + parser.add_argument("--det_east_score_thresh", type=float, default=0.8) + parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) + parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) + + # params for text recognizer + parser.add_argument("--rec_algorithm", type=str, default='CRNN') + parser.add_argument("--rec_model_dir", type=str, default=None) + parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") + parser.add_argument("--rec_char_type", type=str, default='ch') + parser.add_argument("--rec_batch_num", type=int, default=30) + parser.add_argument("--max_text_length", type=int, default=25) + parser.add_argument("--rec_char_dict_path", type=str, default=None) + parser.add_argument("--use_space_char", type=bool, default=True) + parser.add_argument("--drop_score", type=float, default=0.5) + + # params for text classifier + parser.add_argument("--cls_model_dir", type=str, default=None) + parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") + parser.add_argument("--label_list", type=list, default=['0', '180']) + parser.add_argument("--cls_batch_num", type=int, default=30) + parser.add_argument("--cls_thresh", type=float, default=0.9) + + parser.add_argument("--enable_mkldnn", type=bool, default=False) + parser.add_argument("--use_zero_copy_run", type=bool, default=False) + parser.add_argument("--use_pdserving", type=str2bool, default=False) + + parser.add_argument("--lang", type=str, default='ch') + parser.add_argument("--det", type=str2bool, default=True) + parser.add_argument("--rec", type=str2bool, default=True) + parser.add_argument("--use_angle_cls", type=str2bool, default=False) + return parser.parse_args() + else: + return argparse.Namespace(use_gpu=True, + ir_optim=True, + use_tensorrt=False, + gpu_mem=8000, + image_dir='', + det_algorithm='DB', + det_model_dir=None, + det_limit_side_len=960, + det_limit_type='max', + det_db_thresh=0.3, + det_db_box_thresh=0.5, + det_db_unclip_ratio=2.0, + det_east_score_thresh=0.8, + det_east_cover_thresh=0.1, + det_east_nms_thresh=0.2, + rec_algorithm='CRNN', + rec_model_dir=None, + rec_image_shape="3, 32, 320", + rec_char_type='ch', + rec_batch_num=30, + max_text_length=25, + rec_char_dict_path=None, + use_space_char=True, + drop_score=0.5, + cls_model_dir=None, + cls_image_shape="3, 48, 192", + label_list=['0', '180'], + cls_batch_num=30, + cls_thresh=0.9, + enable_mkldnn=False, + use_zero_copy_run=False, + use_pdserving=False, + lang='ch', + det=True, + rec=True, + use_angle_cls=False + ) class PaddleOCR(predict_system.TextSystem): @@ -140,18 +222,31 @@ class PaddleOCR(predict_system.TextSystem): args: **kwargs: other params show in paddleocr --help """ - postprocess_params = parse_args() + postprocess_params = parse_args(mMain=False, add_help=False) postprocess_params.__dict__.update(**kwargs) + self.use_angle_cls = postprocess_params.use_angle_cls + lang = postprocess_params.lang + assert lang in model_urls[ + 'rec'], 'param lang must in {}, but got {}'.format( + model_urls['rec'].keys(), lang) + if postprocess_params.rec_char_dict_path is None: + postprocess_params.rec_char_dict_path = model_urls['rec'][lang][ + 'dict_path'] # init model dir if postprocess_params.det_model_dir is None: postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det') if postprocess_params.rec_model_dir is None: - postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec') + postprocess_params.rec_model_dir = os.path.join( + BASE_DIR, 'rec/{}'.format(lang)) + if postprocess_params.cls_model_dir is None: + postprocess_params.cls_model_dir = os.path.join(BASE_DIR, 'cls') print(postprocess_params) # download model - maybe_download(postprocess_params.det_model_dir, model_params['det']) - maybe_download(postprocess_params.rec_model_dir, model_params['rec']) + maybe_download(postprocess_params.det_model_dir, model_urls['det']) + maybe_download(postprocess_params.rec_model_dir, + model_urls['rec'][lang]['url']) + maybe_download(postprocess_params.cls_model_dir, model_urls['cls']) if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) @@ -166,7 +261,7 @@ class PaddleOCR(predict_system.TextSystem): # init det_model and rec_model super().__init__(postprocess_params) - def ocr(self, img, det=True, rec=True): + def ocr(self, img, det=True, rec=True, cls=False): """ ocr with paddleocr args: @@ -175,7 +270,16 @@ class PaddleOCR(predict_system.TextSystem): rec: use text recognition or not, if false, only det will be exec. default is True """ assert isinstance(img, (np.ndarray, list, str)) + if isinstance(img, list) and det == True: + logger.error('When input a list of images, det must be false') + exit(0) + + self.use_angle_cls = cls if isinstance(img, str): + # download net image + if img.startswith('http'): + download_with_progressbar(img, 'tmp.jpg') + img = 'tmp.jpg' image_file = img img, flag = check_and_read_gif(image_file) if not flag: @@ -183,6 +287,8 @@ class PaddleOCR(predict_system.TextSystem): if img is None: logger.error("error in loading image:{}".format(image_file)) return None + if isinstance(img, np.ndarray) and len(img.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) if det and rec: dt_boxes, rec_res = self.__call__(img) return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] @@ -194,20 +300,34 @@ class PaddleOCR(predict_system.TextSystem): else: if not isinstance(img, list): img = [img] + if self.use_angle_cls: + img, cls_res, elapse = self.text_classifier(img) + if not rec: + return cls_res rec_res, elapse = self.text_recognizer(img) return rec_res def main(): - # for com - args = parse_args() - image_file_list = get_image_file_list(args.image_dir) + # for cmd + args = parse_args(mMain=True) + image_dir = args.image_dir + if image_dir.startswith('http'): + download_with_progressbar(image_dir, 'tmp.jpg') + image_file_list = ['tmp.jpg'] + else: + image_file_list = get_image_file_list(args.image_dir) if len(image_file_list) == 0: logger.error('no images find in {}'.format(args.image_dir)) return - ocr_engine = PaddleOCR() + + ocr_engine = PaddleOCR(**(args.__dict__)) for img_path in image_file_list: - print(img_path) - result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec) - for line in result: - print(line) \ No newline at end of file + logger.info('{}{}{}'.format('*' * 10, img_path, '*' * 10)) + result = ocr_engine.ocr(img_path, + det=args.det, + rec=args.rec, + cls=args.use_angle_cls) + if result is not None: + for line in result: + logger.info(line) diff --git a/setup.py b/setup.py index 6b503ce3..bef6dbbf 100644 --- a/setup.py +++ b/setup.py @@ -32,7 +32,7 @@ setup( package_dir={'paddleocr': ''}, include_package_data=True, entry_points={"console_scripts": ["paddleocr= paddleocr.paddleocr:main"]}, - version='0.0.3', + version='2.0', install_requires=requirements, license='Apache License 2.0', description='Awesome OCR toolkits based on PaddlePaddle (8.6M ultra-lightweight pre-trained model, support training and deployment among server, mobile, embeded and IoT devices', diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index ae660fde..07dfc216 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -13,6 +13,7 @@ # limitations under the License. import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) @@ -30,12 +31,15 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger from tools.infer.utility import draw_ocr_box_txt +logger = get_logger() + class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) self.use_angle_cls = args.use_angle_cls + self.drop_score = args.drop_score if self.use_angle_cls: self.text_classifier = predict_cls.TextClassifier(args) @@ -81,7 +85,8 @@ class TextSystem(object): def __call__(self, img): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) - logger.info("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse)) + logger.info("dt_boxes num : {}, elapse : {}".format( + len(dt_boxes), elapse)) if dt_boxes is None: return None, None img_crop_list = [] @@ -99,9 +104,16 @@ class TextSystem(object): len(img_crop_list), elapse)) rec_res, elapse = self.text_recognizer(img_crop_list) - logger.info("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) + logger.info("rec_res num : {}, elapse : {}".format( + len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) - return dt_boxes, rec_res + filter_boxes, filter_rec_res = [], [] + for box, rec_reuslt in zip(dt_boxes, rec_res): + text, score = rec_reuslt + if score >= self.drop_score: + filter_boxes.append(box) + filter_rec_res.append(rec_reuslt) + return filter_boxes, filter_rec_res def sorted_boxes(dt_boxes): @@ -117,8 +129,8 @@ def sorted_boxes(dt_boxes): _boxes = list(sorted_boxes) for i in range(num_boxes - 1): - if abs(_boxes[i+1][0][1] - _boxes[i][0][1]) < 10 and \ - (_boxes[i + 1][0][0] < _boxes[i][0][0]): + if abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and \ + (_boxes[i + 1][0][0] < _boxes[i][0][0]): tmp = _boxes[i] _boxes[i] = _boxes[i + 1] _boxes[i + 1] = tmp @@ -143,12 +155,8 @@ def main(args): elapse = time.time() - starttime logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) - dt_num = len(dt_boxes) - for dno in range(dt_num): - text, score = rec_res[dno] - if score >= drop_score: - text_str = "%s, %.3f" % (text, score) - logger.info(text_str) + for text, score in rec_res: + logger.info("{}, {:.3f}".format(text, score)) if is_visualize: image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) @@ -174,5 +182,4 @@ def main(args): if __name__ == "__main__": - logger = get_logger() - main(utility.parse_args()) + main(utility.parse_args()) \ No newline at end of file -- GitLab