未验证 提交 ad4853db 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #2925 from WenmuZhou/Optimizing_parameters

combine args in paddleocr and ppocr/infer/utility
...@@ -59,7 +59,7 @@ im_show.save('result.jpg') ...@@ -59,7 +59,7 @@ im_show.save('result.jpg')
from paddleocr import PaddleOCR, draw_ocr from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR() # need to run only once to download and load model into memory ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg' img_path = 'PaddleOCR/doc/imgs/11.jpg'
result = ocr.ocr(img_path) result = ocr.ocr(img_path,cls=False)
for line in result: for line in result:
print(line) print(line)
......
...@@ -59,7 +59,7 @@ Visualization of results ...@@ -59,7 +59,7 @@ Visualization of results
from paddleocr import PaddleOCR,draw_ocr from paddleocr import PaddleOCR,draw_ocr
ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
result = ocr.ocr(img_path) result = ocr.ocr(img_path, cls=False)
for line in result: for line in result:
print(line) print(line)
......
...@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger ...@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
...@@ -167,106 +167,24 @@ def maybe_download(model_storage_directory, url): ...@@ -167,106 +167,24 @@ def maybe_download(model_storage_directory, url):
os.remove(tmp_path) os.remove(tmp_path)
def parse_args(mMain=True, add_help=True): def parse_args(mMain=True):
import argparse import argparse
parser = init_args()
def str2bool(v): parser.add_help = mMain
return v.lower() in ("true", "t", "1")
if mMain:
parser = argparse.ArgumentParser(add_help=add_help)
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument("--rec_char_dict_path", type=str, default=None)
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch') parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
for action in parser._actions:
if action.dest == 'rec_char_dict_path':
action.default = None
if mMain:
return parser.parse_args() return parser.parse_args()
else: else:
return argparse.Namespace( inference_args_dict = {}
use_gpu=True, for action in parser._actions:
ir_optim=True, inference_args_dict[action.dest] = action.default
use_tensorrt=False, return argparse.Namespace(**inference_args_dict)
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
use_dilation=False,
det_db_score_mode="fast",
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=6,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=6,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False)
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
...@@ -276,7 +194,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -276,7 +194,7 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args(mMain=False, add_help=False) postprocess_params = parse_args(mMain=False)
postprocess_params.__dict__.update(**kwargs) postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang lang = postprocess_params.lang
...@@ -346,7 +264,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -346,7 +264,7 @@ class PaddleOCR(predict_system.TextSystem):
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(postprocess_params)
def ocr(self, img, det=True, rec=True, cls=False): def ocr(self, img, det=True, rec=True, cls=True):
""" """
ocr with paddleocr ocr with paddleocr
args: args:
...@@ -358,9 +276,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -358,9 +276,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, list) and det == True: if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false') logger.error('When input a list of images, det must be false')
exit(0) exit(0)
if cls == False: if cls == True and self.use_angle_cls == False:
self.use_angle_cls = False
elif cls == True and self.use_angle_cls == False:
logger.warning( logger.warning(
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
) )
...@@ -382,7 +298,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -382,7 +298,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2: if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec: if det and rec:
dt_boxes, rec_res = self.__call__(img) dt_boxes, rec_res = self.__call__(img, cls)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
elif det and not rec: elif det and not rec:
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
...@@ -392,7 +308,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -392,7 +308,7 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
if not isinstance(img, list): if not isinstance(img, list):
img = [img] img = [img]
if self.use_angle_cls: if self.use_angle_cls and cls:
img, cls_res, elapse = self.text_classifier(img) img, cls_res, elapse = self.text_classifier(img)
if not rec: if not rec:
return cls_res return cls_res
......
...@@ -85,7 +85,7 @@ class TextSystem(object): ...@@ -85,7 +85,7 @@ class TextSystem(object):
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
logger.info(bno, rec_res[bno]) logger.info(bno, rec_res[bno])
def __call__(self, img): def __call__(self, img, cls=True):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format( logger.info("dt_boxes num : {}, elapse : {}".format(
...@@ -100,7 +100,7 @@ class TextSystem(object): ...@@ -100,7 +100,7 @@ class TextSystem(object):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
if self.use_angle_cls: if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list) img_crop_list)
logger.info("cls num : {}, elapse : {}".format( logger.info("cls num : {}, elapse : {}".format(
......
...@@ -23,13 +23,15 @@ import math ...@@ -23,13 +23,15 @@ import math
from paddle import inference from paddle import inference
import time import time
from ppocr.utils.logging import get_logger from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
def parse_args(): def str2bool(v):
def str2bool(v):
return v.lower() in ("true", "t", "1") return v.lower() in ("true", "t", "1")
def init_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
# params for prediction engine # params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True) parser.add_argument("--use_gpu", type=str2bool, default=True)
...@@ -108,6 +110,11 @@ def parse_args(): ...@@ -108,6 +110,11 @@ def parse_args():
parser.add_argument("--total_process_num", type=int, default=1) parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0) parser.add_argument("--process_id", type=int, default=0)
return parser
def parse_args():
parser = init_args()
return parser.parse_args() return parser.parse_args()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册