From 7d472831286f84f1c3853a01eebad127412f536b Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Thu, 10 Jun 2021 14:47:23 +0800 Subject: [PATCH] support model link in model_dir params --- paddleocr.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/paddleocr.py b/paddleocr.py index 48c8c9c6..95896817 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list -from ppocr.utils.network import maybe_download, download_with_progressbar +from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url from tools.infer.utility import draw_ocr, init_args, str2bool __all__ = ['PaddleOCR'] @@ -192,20 +192,19 @@ class PaddleOCR(predict_system.TextSystem): 'dict_path'] # init model dir - if params.det_model_dir is None: - params.det_model_dir = os.path.join(BASE_DIR, VERSION, - 'det', det_lang) - if params.rec_model_dir is None: - params.rec_model_dir = os.path.join(BASE_DIR, VERSION, - 'rec', lang) - if params.cls_model_dir is None: - params.cls_model_dir = os.path.join(BASE_DIR, 'cls') + params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir, + os.path.join(BASE_DIR, VERSION, 'det', det_lang), + model_urls['det'][det_lang]) + params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir, + os.path.join(BASE_DIR, VERSION, 'rec', lang), + model_urls['rec'][lang]['url']) + params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir, + os.path.join(BASE_DIR, VERSION, 'cls'), + model_urls['cls']) # download model - maybe_download(params.det_model_dir, - model_urls['det'][det_lang]) - maybe_download(params.rec_model_dir, - model_urls['rec'][lang]['url']) - maybe_download(params.cls_model_dir, model_urls['cls']) + maybe_download(params.det_model_dir, det_url) + maybe_download(params.rec_model_dir, rec_url) + maybe_download(params.cls_model_dir, cls_url) if params.det_algorithm not in SUPPORT_DET_MODEL: logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) @@ -277,7 +276,7 @@ def main(): # for cmd args = parse_args(mMain=True) image_dir = args.image_dir - if image_dir.startswith('http'): + if is_link(image_dir): download_with_progressbar(image_dir, 'tmp.jpg') image_file_list = ['tmp.jpg'] else: -- GitLab