提交 eaf38b9b 作者: WenmuZhou

Combine the argument definitions of paddleocr.py and tools/infer/utility.py

上级 5d24736a
...@@ -59,7 +59,7 @@ im_show.save('result.jpg') ...@@ -59,7 +59,7 @@ im_show.save('result.jpg')
from paddleocr import PaddleOCR, draw_ocr from paddleocr import PaddleOCR, draw_ocr
ocr = PaddleOCR() # need to run only once to download and load model into memory ocr = PaddleOCR() # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs/11.jpg' img_path = 'PaddleOCR/doc/imgs/11.jpg'
result = ocr.ocr(img_path) result = ocr.ocr(img_path,cls=False)
for line in result: for line in result:
print(line) print(line)
......
...@@ -59,7 +59,7 @@ Visualization of results ...@@ -59,7 +59,7 @@ Visualization of results
from paddleocr import PaddleOCR,draw_ocr from paddleocr import PaddleOCR,draw_ocr
ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory ocr = PaddleOCR(lang='en') # need to run only once to download and load model into memory
img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg' img_path = 'PaddleOCR/doc/imgs_en/img_12.jpg'
result = ocr.ocr(img_path) result = ocr.ocr(img_path, cls=False)
for line in result: for line in result:
print(line) print(line)
......
...@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger ...@@ -30,7 +30,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from tools.infer.utility import draw_ocr from tools.infer.utility import draw_ocr, inference_args_list, str2bool, parse_args
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
...@@ -167,106 +167,36 @@ def maybe_download(model_storage_directory, url): ...@@ -167,106 +167,36 @@ def maybe_download(model_storage_directory, url):
os.remove(tmp_path) os.remove(tmp_path)
def parse_args(mMain=True, add_help=True): def parse_args_whl(mMain=True):
import argparse import argparse
extend_args_list = [
def str2bool(v): {
return v.lower() in ("true", "t", "1") 'name': 'lang',
'type': str,
'default': 'ch'
},
{
'name': 'det',
'type': str2bool,
'default': True
},
{
'name': 'rec',
'type': str2bool,
'default': True
},
]
for item in inference_args_list:
if item['name'] == 'rec_char_dict_path':
item['default'] = None
inference_args_list.extend(extend_args_list)
if mMain: if mMain:
parser = argparse.ArgumentParser(add_help=add_help) return parse_args()
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
# params for text detector
parser.add_argument("--image_dir", type=str)
parser.add_argument("--det_algorithm", type=str, default='DB')
parser.add_argument("--det_model_dir", type=str, default=None)
parser.add_argument("--det_limit_side_len", type=float, default=960)
parser.add_argument("--det_limit_type", type=str, default='max')
# DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3)
parser.add_argument("--det_db_box_thresh", type=float, default=0.5)
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6)
parser.add_argument("--use_dilation", type=bool, default=False)
parser.add_argument("--det_db_score_mode", type=str, default="fast")
# EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8)
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1)
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2)
# params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN')
parser.add_argument("--rec_model_dir", type=str, default=None)
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
parser.add_argument("--rec_char_type", type=str, default='ch')
parser.add_argument("--rec_batch_num", type=int, default=6)
parser.add_argument("--max_text_length", type=int, default=25)
parser.add_argument("--rec_char_dict_path", type=str, default=None)
parser.add_argument("--use_space_char", type=bool, default=True)
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
parser.add_argument("--cls_model_dir", type=str, default=None)
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192")
parser.add_argument("--label_list", type=list, default=['0', '180'])
parser.add_argument("--cls_batch_num", type=int, default=6)
parser.add_argument("--cls_thresh", type=float, default=0.9)
parser.add_argument("--enable_mkldnn", type=bool, default=False)
parser.add_argument("--use_zero_copy_run", type=bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--lang", type=str, default='ch')
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--use_angle_cls", type=str2bool, default=False)
return parser.parse_args()
else: else:
return argparse.Namespace( inference_args_dict = {}
use_gpu=True, for item in inference_args_list:
ir_optim=True, inference_args_dict[item['name']] = item['default']
use_tensorrt=False, return argparse.Namespace(**inference_args_dict)
gpu_mem=8000,
image_dir='',
det_algorithm='DB',
det_model_dir=None,
det_limit_side_len=960,
det_limit_type='max',
det_db_thresh=0.3,
det_db_box_thresh=0.5,
det_db_unclip_ratio=1.6,
use_dilation=False,
det_db_score_mode="fast",
det_east_score_thresh=0.8,
det_east_cover_thresh=0.1,
det_east_nms_thresh=0.2,
rec_algorithm='CRNN',
rec_model_dir=None,
rec_image_shape="3, 32, 320",
rec_char_type='ch',
rec_batch_num=6,
max_text_length=25,
rec_char_dict_path=None,
use_space_char=True,
drop_score=0.5,
cls_model_dir=None,
cls_image_shape="3, 48, 192",
label_list=['0', '180'],
cls_batch_num=6,
cls_thresh=0.9,
enable_mkldnn=False,
use_zero_copy_run=False,
use_pdserving=False,
lang='ch',
det=True,
rec=True,
use_angle_cls=False)
class PaddleOCR(predict_system.TextSystem): class PaddleOCR(predict_system.TextSystem):
...@@ -276,7 +206,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -276,7 +206,7 @@ class PaddleOCR(predict_system.TextSystem):
args: args:
**kwargs: other params show in paddleocr --help **kwargs: other params show in paddleocr --help
""" """
postprocess_params = parse_args(mMain=False, add_help=False) postprocess_params = parse_args_whl(mMain=False)
postprocess_params.__dict__.update(**kwargs) postprocess_params.__dict__.update(**kwargs)
self.use_angle_cls = postprocess_params.use_angle_cls self.use_angle_cls = postprocess_params.use_angle_cls
lang = postprocess_params.lang lang = postprocess_params.lang
...@@ -346,7 +276,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -346,7 +276,7 @@ class PaddleOCR(predict_system.TextSystem):
# init det_model and rec_model # init det_model and rec_model
super().__init__(postprocess_params) super().__init__(postprocess_params)
def ocr(self, img, det=True, rec=True, cls=False): def ocr(self, img, det=True, rec=True, cls=True):
""" """
ocr with paddleocr ocr with paddleocr
args: args:
...@@ -358,9 +288,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -358,9 +288,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, list) and det == True: if isinstance(img, list) and det == True:
logger.error('When input a list of images, det must be false') logger.error('When input a list of images, det must be false')
exit(0) exit(0)
if cls == False: if cls == True and self.use_angle_cls == False:
self.use_angle_cls = False
elif cls == True and self.use_angle_cls == False:
logger.warning( logger.warning(
'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process' 'Since the angle classifier is not initialized, the angle classifier will not be uesd during the forward process'
) )
...@@ -382,7 +310,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -382,7 +310,7 @@ class PaddleOCR(predict_system.TextSystem):
if isinstance(img, np.ndarray) and len(img.shape) == 2: if isinstance(img, np.ndarray) and len(img.shape) == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if det and rec: if det and rec:
dt_boxes, rec_res = self.__call__(img) dt_boxes, rec_res = self.__call__(img, cls)
return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)] return [[box.tolist(), res] for box, res in zip(dt_boxes, rec_res)]
elif det and not rec: elif det and not rec:
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
...@@ -392,7 +320,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -392,7 +320,7 @@ class PaddleOCR(predict_system.TextSystem):
else: else:
if not isinstance(img, list): if not isinstance(img, list):
img = [img] img = [img]
if self.use_angle_cls: if self.use_angle_cls and cls:
img, cls_res, elapse = self.text_classifier(img) img, cls_res, elapse = self.text_classifier(img)
if not rec: if not rec:
return cls_res return cls_res
...@@ -402,7 +330,7 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -402,7 +330,7 @@ class PaddleOCR(predict_system.TextSystem):
def main(): def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args_whl(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if image_dir.startswith('http'):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')
......
...@@ -85,7 +85,7 @@ class TextSystem(object): ...@@ -85,7 +85,7 @@ class TextSystem(object):
cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno])
logger.info(bno, rec_res[bno]) logger.info(bno, rec_res[bno])
def __call__(self, img): def __call__(self, img, cls=True):
ori_im = img.copy() ori_im = img.copy()
dt_boxes, elapse = self.text_detector(img) dt_boxes, elapse = self.text_detector(img)
logger.info("dt_boxes num : {}, elapse : {}".format( logger.info("dt_boxes num : {}, elapse : {}".format(
...@@ -100,7 +100,7 @@ class TextSystem(object): ...@@ -100,7 +100,7 @@ class TextSystem(object):
tmp_box = copy.deepcopy(dt_boxes[bno]) tmp_box = copy.deepcopy(dt_boxes[bno])
img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
img_crop_list.append(img_crop) img_crop_list.append(img_crop)
if self.use_angle_cls: if self.use_angle_cls and cls:
img_crop_list, angle_list, elapse = self.text_classifier( img_crop_list, angle_list, elapse = self.text_classifier(
img_crop_list) img_crop_list)
logger.info("cls num : {}, elapse : {}".format( logger.info("cls num : {}, elapse : {}".format(
......
...@@ -23,87 +23,288 @@ import math ...@@ -23,87 +23,288 @@ import math
from paddle import inference from paddle import inference
def parse_args(): def str2bool(v):
def str2bool(v): return v.lower() in ("true", "t", "1")
return v.lower() in ("true", "t", "1")
parser = argparse.ArgumentParser()
# params for prediction engine
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=500)
inference_args_list = [
# params for prediction engine
{
'name': 'use_gpu',
'type': str2bool,
'default': True
},
{
'name': 'ir_optim',
'type': str2bool,
'default': True
},
{
'name': 'use_tensorrt',
'type': str2bool,
'default': False
},
{
'name': 'use_fp16',
'type': str2bool,
'default': False
},
{
'name': 'enable_mkldnn',
'type': str2bool,
'default': False
},
{
'name': 'use_pdserving',
'type': str2bool,
'default': False
},
{
'name': 'use_mp',
'type': str2bool,
'default': False
},
{
'name': 'total_process_num',
'type': int,
'default': 1
},
{
'name': 'process_id',
'type': int,
'default': 0
},
{
'name': 'gpu_mem',
'type': int,
'default': 500
},
# params for text detector # params for text detector
parser.add_argument("--image_dir", type=str) {
parser.add_argument("--det_algorithm", type=str, default='DB') 'name': 'image_dir',
parser.add_argument("--det_model_dir", type=str) 'type': str,
parser.add_argument("--det_limit_side_len", type=float, default=960) 'default': None
parser.add_argument("--det_limit_type", type=str, default='max') },
{
'name': 'det_algorithm',
'type': str,
'default': 'DB'
},
{
'name': 'det_model_dir',
'type': str,
'default': None
},
{
'name': 'det_limit_side_len',
'type': float,
'default': 960
},
{
'name': 'det_limit_type',
'type': str,
'default': 'max'
},
# DB parmas # DB parmas
parser.add_argument("--det_db_thresh", type=float, default=0.3) {
parser.add_argument("--det_db_box_thresh", type=float, default=0.5) 'name': 'det_db_thresh',
parser.add_argument("--det_db_unclip_ratio", type=float, default=1.6) 'type': float,
parser.add_argument("--max_batch_size", type=int, default=10) 'default': 0.3
parser.add_argument("--use_dilation", type=bool, default=False) },
parser.add_argument("--det_db_score_mode", type=str, default="fast") {
'name': 'det_db_box_thresh',
'type': float,
'default': 0.5
},
{
'name': 'det_db_unclip_ratio',
'type': float,
'default': 1.6
},
{
'name': 'max_batch_size',
'type': int,
'default': 10
},
{
'name': 'use_dilation',
'type': str2bool,
'default': False
},
{
'name': 'det_db_score_mode',
'type': str,
'default': 'fast'
},
# EAST parmas # EAST parmas
parser.add_argument("--det_east_score_thresh", type=float, default=0.8) {
parser.add_argument("--det_east_cover_thresh", type=float, default=0.1) 'name': 'det_east_score_thresh',
parser.add_argument("--det_east_nms_thresh", type=float, default=0.2) 'type': float,
'default': 0.8
},
{
'name': 'det_east_cover_thresh',
'type': float,
'default': 0.1
},
{
'name': 'det_east_nms_thresh',
'type': float,
'default': 0.2
},
# SAST parmas # SAST parmas
parser.add_argument("--det_sast_score_thresh", type=float, default=0.5) {
parser.add_argument("--det_sast_nms_thresh", type=float, default=0.2) 'name': 'det_sast_score_thresh',
parser.add_argument("--det_sast_polygon", type=bool, default=False) 'type': float,
'default': 0.5
},
{
'name': 'det_sast_nms_thresh',
'type': float,
'default': 0.2
},
{
'name': 'det_sast_polygon',
'type': str2bool,
'default': False
},
# params for text recognizer # params for text recognizer
parser.add_argument("--rec_algorithm", type=str, default='CRNN') {
parser.add_argument("--rec_model_dir", type=str) 'name': 'rec_algorithm',
parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") 'type': str,
parser.add_argument("--rec_char_type", type=str, default='ch') 'default': 'CRNN'
parser.add_argument("--rec_batch_num", type=int, default=6) },
parser.add_argument("--max_text_length", type=int, default=25) {
parser.add_argument( 'name': 'rec_model_dir',
"--rec_char_dict_path", 'type': str,
type=str, 'default': None
default="./ppocr/utils/ppocr_keys_v1.txt") },
parser.add_argument("--use_space_char", type=str2bool, default=True) {
parser.add_argument( 'name': 'rec_image_shape',
"--vis_font_path", type=str, default="./doc/fonts/simfang.ttf") 'type': str,
parser.add_argument("--drop_score", type=float, default=0.5) 'default': '3, 32, 320'
},
{
'name': 'rec_char_type',
'type': str,
'default': "ch"
},
{
'name': 'rec_batch_num',
'type': int,
'default': 6
},
{
'name': 'max_text_length',
'type': int,
'default': 25
},
{
'name': 'rec_char_dict_path',
'type': str,
'default': './ppocr/utils/ppocr_keys_v1.txt'
},
{
'name': 'use_space_char',
'type': str2bool,
'default': True
},
{
'name': 'vis_font_path',
'type': str,
'default': './doc/fonts/simfang.ttf'
},
{
'name': 'drop_score',
'type': float,
'default': 0.5
},
# params for e2e # params for e2e
parser.add_argument("--e2e_algorithm", type=str, default='PGNet') {
parser.add_argument("--e2e_model_dir", type=str) 'name': 'e2e_algorithm',
parser.add_argument("--e2e_limit_side_len", type=float, default=768) 'type': str,
parser.add_argument("--e2e_limit_type", type=str, default='max') 'default': 'PGNet'
},
{
'name': 'e2e_model_dir',
'type': str,
'default': None
},
{
'name': 'e2e_limit_side_len',
'type': float,
'default': 768
},
{
'name': 'e2e_limit_type',
'type': str,
'default': 'max'
},
# PGNet parmas # PGNet parmas
parser.add_argument("--e2e_pgnet_score_thresh", type=float, default=0.5) {
parser.add_argument( 'name': 'e2e_pgnet_score_thresh',
"--e2e_char_dict_path", type=str, default="./ppocr/utils/ic15_dict.txt") 'type': float,
parser.add_argument("--e2e_pgnet_valid_set", type=str, default='totaltext') 'default': 0.5
parser.add_argument("--e2e_pgnet_polygon", type=bool, default=True) },
parser.add_argument("--e2e_pgnet_mode", type=str, default='fast') {
'name': 'e2e_char_dict_path',
'type': str,
'default': './ppocr/utils/ic15_dict.txt'
},
{
'name': 'e2e_pgnet_valid_set',
'type': str,
'default': 'totaltext'
},
{
'name': 'e2e_pgnet_polygon',
'type': str2bool,
'default': True
},
{
'name': 'e2e_pgnet_mode',
'type': str,
'default': 'fast'
},
# params for text classifier # params for text classifier
parser.add_argument("--use_angle_cls", type=str2bool, default=False) {
parser.add_argument("--cls_model_dir", type=str) 'name': 'use_angle_cls',
parser.add_argument("--cls_image_shape", type=str, default="3, 48, 192") 'type': str2bool,
parser.add_argument("--label_list", type=list, default=['0', '180']) 'default': False
parser.add_argument("--cls_batch_num", type=int, default=6) },
parser.add_argument("--cls_thresh", type=float, default=0.9) {
'name': 'cls_model_dir',
'type': str,
'default': None
},
{
'name': 'cls_image_shape',
'type': str,
'default': '3, 48, 192'
},
{
'name': 'label_list',
'type': list,
'default': ['0', '180']
},
{
'name': 'cls_batch_num',
'type': int,
'default': 6
},
{
'name': 'cls_thresh',
'type': float,
'default': 0.9
},
]
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--use_pdserving", type=str2bool, default=False)
parser.add_argument("--use_mp", type=str2bool, default=False)
parser.add_argument("--total_process_num", type=int, default=1)
parser.add_argument("--process_id", type=int, default=0)
def parse_args():
parser = argparse.ArgumentParser()
for item in inference_args_list:
parser.add_argument(
'--' + item['name'], type=item['type'], default=item['default'])
return parser.parse_args() return parser.parse_args()
...@@ -146,7 +347,7 @@ def create_predictor(args, mode, logger): ...@@ -146,7 +347,7 @@ def create_predictor(args, mode, logger):
config.set_mkldnn_cache_capacity(10) config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn() config.enable_mkldnn()
# TODO LDOUBLEV: fix mkldnn bug when bach_size > 1 # TODO LDOUBLEV: fix mkldnn bug when bach_size > 1
#config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'}) # config.set_mkldnn_op({'conv2d', 'depthwise_conv2d', 'pool2d', 'batch_norm'})
args.rec_batch_num = 1 args.rec_batch_num = 1
# enable memory optim # enable memory optim
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册