diff --git a/paddleocr.py b/paddleocr.py index a530d7868523019920652ec4611832e00fc4da3a..f19d1defee217a8d3dc0653e6b8fd1713cb389fa 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -47,7 +47,9 @@ SUPPORT_REC_MODEL = ['CRNN'] BASE_DIR = os.path.expanduser("~/.paddleocr/") DEFAULT_OCR_MODEL_VERSION = 'PP-OCR' +SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2'] DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE' +SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE'] MODEL_URLS = { 'OCR': { 'PP-OCRv2': { @@ -190,6 +192,7 @@ def parse_args(mMain=True): parser.add_argument( "--ocr_version", type=str, + choices=SUPPORT_OCR_MODEL_VERSION, default='PP-OCRv2', help='OCR Model version, the current model support list is as follows: ' '1. PP-OCRv2 Support Chinese detection and recognition model. ' @@ -198,6 +201,7 @@ def parse_args(mMain=True): parser.add_argument( "--structure_version", type=str, + choices=SUPPORT_STRUCTURE_MODEL_VERSION, default='STRUCTURE', help='Model version, the current model support list is as follows:' ' 1. STRUCTURE Support en table structure model.') @@ -257,26 +261,20 @@ def get_model_config(type, version, model_type, lang): DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION else: raise NotImplementedError + model_urls = MODEL_URLS[type] if version not in model_urls: - logger.warning('version {} not in {}, auto switch to version {}'.format( - version, model_urls.keys(), DEFAULT_MODEL_VERSION)) version = DEFAULT_MODEL_VERSION if model_type not in model_urls[version]: if model_type in model_urls[DEFAULT_MODEL_VERSION]: - logger.warning( - 'version {} not support {} models, auto switch to version {}'. - format(version, model_type, DEFAULT_MODEL_VERSION)) version = DEFAULT_MODEL_VERSION else: logger.error('{} models is not support, we only support {}'.format( model_type, model_urls[DEFAULT_MODEL_VERSION].keys())) sys.exit(-1) + if lang not in model_urls[version][model_type]: if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]: - logger.warning( - 'lang {} is not support in {}, auto switch to version {}'. - format(lang, version, DEFAULT_MODEL_VERSION)) version = DEFAULT_MODEL_VERSION else: logger.error( @@ -296,6 +294,8 @@ class PaddleOCR(predict_system.TextSystem): """ params = parse_args(mMain=False) params.__dict__.update(**kwargs) + assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format( + SUPPORT_OCR_MODEL_VERSION, params.ocr_version) params.use_gpu = check_gpu(params.use_gpu) if not params.show_log: @@ -398,6 +398,8 @@ class PPStructure(OCRSystem): def __init__(self, **kwargs): params = parse_args(mMain=False) params.__dict__.update(**kwargs) + assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "structure_version must in {}, but get {}".format( + SUPPORT_STRUCTURE_MODEL_VERSION, params.structure_version) params.use_gpu = check_gpu(params.use_gpu) if not params.show_log: