提交 7d472831 编写于 作者: W WenmuZhou

support model link in model_dir params

上级 037e17fc
...@@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger ...@@ -28,7 +28,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger() logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, init_args, str2bool from tools.infer.utility import draw_ocr, init_args, str2bool
__all__ = ['PaddleOCR'] __all__ = ['PaddleOCR']
...@@ -192,20 +192,19 @@ class PaddleOCR(predict_system.TextSystem): ...@@ -192,20 +192,19 @@ class PaddleOCR(predict_system.TextSystem):
'dict_path'] 'dict_path']
# init model dir # init model dir
if params.det_model_dir is None: params.det_model_dir, det_url = confirm_model_dir_url(params.det_model_dir,
params.det_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'det', det_lang),
'det', det_lang) model_urls['det'][det_lang])
if params.rec_model_dir is None: params.rec_model_dir, rec_url = confirm_model_dir_url(params.rec_model_dir,
params.rec_model_dir = os.path.join(BASE_DIR, VERSION, os.path.join(BASE_DIR, VERSION, 'rec', lang),
'rec', lang) model_urls['rec'][lang]['url'])
if params.cls_model_dir is None: params.cls_model_dir, cls_url = confirm_model_dir_url(params.cls_model_dir,
params.cls_model_dir = os.path.join(BASE_DIR, 'cls') os.path.join(BASE_DIR, VERSION, 'cls'),
model_urls['cls'])
# download model # download model
maybe_download(params.det_model_dir, maybe_download(params.det_model_dir, det_url)
model_urls['det'][det_lang]) maybe_download(params.rec_model_dir, rec_url)
maybe_download(params.rec_model_dir, maybe_download(params.cls_model_dir, cls_url)
model_urls['rec'][lang]['url'])
maybe_download(params.cls_model_dir, model_urls['cls'])
if params.det_algorithm not in SUPPORT_DET_MODEL: if params.det_algorithm not in SUPPORT_DET_MODEL:
logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL)) logger.error('det_algorithm must in {}'.format(SUPPORT_DET_MODEL))
...@@ -277,7 +276,7 @@ def main(): ...@@ -277,7 +276,7 @@ def main():
# for cmd # for cmd
args = parse_args(mMain=True) args = parse_args(mMain=True)
image_dir = args.image_dir image_dir = args.image_dir
if image_dir.startswith('http'): if is_link(image_dir):
download_with_progressbar(image_dir, 'tmp.jpg') download_with_progressbar(image_dir, 'tmp.jpg')
image_file_list = ['tmp.jpg'] image_file_list = ['tmp.jpg']
else: else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册