diff --git a/paddleocr.py b/paddleocr.py index 3c356bb3ee698c483163a488859e2d4525a2aa46..65bca7ae243e15e4788b5b637be65d57cf9504e5 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -29,25 +29,19 @@ from tools.infer import predict_system from ppocr.utils.utility import initial_logger logger = initial_logger() -from ppocr.utils.utility import check_and_read_gif +from ppocr.utils.utility import check_and_read_gif, get_image_file_list __all__ = ['PaddleOCR'] model_params = { - 'ch_det_mv3_db': { - 'url': - 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar', - 'algorithm': 'DB', - }, - 'ch_rec_mv3_crnn_enhance': { - 'url': - 'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar', - 'algorithm': 'CRNN' - }, + 'det': 'https://paddleocr.bj.bcebos.com/ch_models/ch_det_mv3_db_infer.tar', + 'rec': + 'https://paddleocr.bj.bcebos.com/ch_models/ch_rec_mv3_crnn_enhance_infer.tar', } SUPPORT_DET_MODEL = ['DB'] -SUPPORT_REC_MODEL = ['Rosetta', 'CRNN', 'STARNet', 'RARE'] +SUPPORT_REC_MODEL = ['CRNN'] +BASE_DIR = os.path.expanduser("~/.paddleocr/") def download_with_progressbar(url, save_path): @@ -65,34 +59,29 @@ def download_with_progressbar(url, save_path): sys.exit(0) -def download_and_unzip(url, model_storage_directory): - tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) - print('download {} to {}'.format(url, tmp_path)) - os.makedirs(model_storage_directory, exist_ok=True) - download_with_progressbar(url, tmp_path) - with tarfile.open(tmp_path, 'r') as tarObj: - for filename in tarObj.getnames(): - tarObj.extract(filename, model_storage_directory) - os.remove(tmp_path) - - -def maybe_download(model_storage_directory, model_name, mode='det'): - algorithm = None +def maybe_download(model_storage_directory, url): # using custom model - if os.path.exists(os.path.join(model_name, 'model')) and os.path.exists( - os.path.join(model_name, 'params')): - return model_name, algorithm - # using the model of paddleocr - model_path = os.path.join(model_storage_directory, model_name) - if not os.path.exists(os.path.join(model_path, - 'model')) or not os.path.exists( - os.path.join(model_path, 'params')): - assert model_name in model_params, 'model must in {}'.format( - model_params.keys()) - download_and_unzip(model_params[model_name]['url'], - model_storage_directory) - algorithm = model_params[model_name]['algorithm'] - return model_path, algorithm + if not os.path.exists(os.path.join( + model_storage_directory, 'model')) or not os.path.exists( + os.path.join(model_storage_directory, 'params')): + tmp_path = os.path.join(model_storage_directory, url.split('/')[-1]) + print('download {} to {}'.format(url, tmp_path)) + os.makedirs(model_storage_directory, exist_ok=True) + download_with_progressbar(url, tmp_path) + with tarfile.open(tmp_path, 'r') as tarObj: + for member in tarObj.getmembers(): + if "model" in member.name: + filename = 'model' + elif "params" in member.name: + filename = 'params' + else: + continue + file = tarObj.extractfile(member) + with open( + os.path.join(model_storage_directory, filename), + 'wb') as f: + f.write(file.read()) + os.remove(tmp_path) def parse_args(): @@ -111,7 +100,7 @@ def parse_args(): # params for text detector parser.add_argument("--image_dir", type=str) parser.add_argument("--det_algorithm", type=str, default='DB') - parser.add_argument("--det_model_name", type=str, default='ch_det_mv3_db') + parser.add_argument("--det_model_dir", type=str, default=None) parser.add_argument("--det_max_side_len", type=float, default=960) # DB parmas @@ -126,11 +115,11 @@ def parse_args(): # params for text recognizer parser.add_argument("--rec_algorithm", type=str, default='CRNN') - parser.add_argument( - "--rec_model_name", type=str, default='ch_rec_mv3_crnn_enhance') + parser.add_argument("--rec_model_dir", type=str, default=None) parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320") parser.add_argument("--rec_char_type", type=str, default='ch') parser.add_argument("--rec_batch_num", type=int, default=30) + parser.add_argument("--max_text_length", type=int, default=25) parser.add_argument( "--rec_char_dict_path", type=str, @@ -138,53 +127,30 @@ def parse_args(): parser.add_argument("--use_space_char", type=bool, default=True) parser.add_argument("--enable_mkldnn", type=bool, default=False) - parser.add_argument("--model_storage_directory", type=str, default=False) parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True) return parser.parse_args() class PaddleOCR(predict_system.TextSystem): - def __init__(self, - det_model_name='ch_det_mv3_db', - rec_model_name='ch_rec_mv3_crnn_enhance', - model_storage_directory=None, - log_level=20, - **kwargs): + def __init__(self, **kwargs): """ paddleocr package args: - det_model_name: det_model name, keep same with filename in paddleocr. default is ch_det_mv3_db - det_model_name: rec_model name, keep same with filename in paddleocr. default is ch_rec_mv3_crnn_enhance - model_storage_directory: model save path. default is ~/.paddleocr - det model will save to model_storage_directory/det_model - rec model will save to model_storage_directory/rec_model - log_level: **kwargs: other params show in paddleocr --help """ - logger.setLevel(log_level) postprocess_params = parse_args() - # init model dir - if model_storage_directory: - self.model_storage_directory = model_storage_directory - else: - self.model_storage_directory = os.path.expanduser( - "~/.paddleocr/") + '/model' - Path(self.model_storage_directory).mkdir(parents=True, exist_ok=True) + postprocess_params.__dict__.update(**kwargs) + # init model dir + if postprocess_params.det_model_dir is None: + postprocess_params.det_model_dir = os.path.join(BASE_DIR, 'det') + if postprocess_params.rec_model_dir is None: + postprocess_params.rec_model_dir = os.path.join(BASE_DIR, 'rec') + print(postprocess_params) # download model - det_model_path, det_algorithm = maybe_download( - self.model_storage_directory, det_model_name, 'det') - rec_model_path, rec_algorithm = maybe_download( - self.model_storage_directory, rec_model_name, 'rec') - # update model and post_process params - postprocess_params.__dict__.update(**kwargs) - postprocess_params.det_model_dir = det_model_path - postprocess_params.rec_model_dir = rec_model_path - if det_algorithm is not None: - postprocess_params.det_algorithm = det_algorithm - if rec_algorithm is not None: - postprocess_params.rec_algorithm = rec_algorithm + maybe_download(postprocess_params.det_model_dir, model_params['det']) + maybe_download(postprocess_params.rec_model_dir, model_params['rec']) if postprocess_params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) @@ -229,3 +195,18 @@ class PaddleOCR(predict_system.TextSystem): img = [img] rec_res, elapse = self.text_recognizer(img) return rec_res + + +def main(): + # for com + args = parse_args() + image_file_list = get_image_file_list(args.image_dir) + if len(image_file_list) == 0: + logger.error('no images find in {}'.format(args.image_dir)) + return + ocr_engine = PaddleOCR() + for img_path in image_file_list: + print(img_path) + result = ocr_engine.ocr(img_path, det=args.det, rec=args.rec) + for line in result: + print(line)