diff --git a/paddleocr.py b/paddleocr.py index d52a88b3778ef3924566e0569a7f86d2da578eec..bc9b4e276c38eef127c3af615e69a6c61e41b928 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -47,7 +47,9 @@ SUPPORT_REC_MODEL = ['CRNN'] BASE_DIR = os.path.expanduser("~/.paddleocr/") DEFAULT_OCR_MODEL_VERSION = 'PP-OCR' +SUPPORT_OCR_MODEL_VERSION = ['PP-OCR', 'PP-OCRv2'] DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE' +SUPPORT_STRUCTURE_MODEL_VERSION = ['STRUCTURE'] MODEL_URLS = { 'OCR': { 'PP-OCRv2': { @@ -190,7 +192,7 @@ def parse_args(mMain=True): parser.add_argument( "--ocr_version", type=str, - choices=['PP-OCR', 'PP-OCRv2'], + choices=SUPPORT_OCR_MODEL_VERSION, default='PP-OCRv2', help='OCR Model version, the current model support list is as follows: ' '1. PP-OCRv2 Support Chinese detection and recognition model. ' @@ -199,7 +201,7 @@ def parse_args(mMain=True): parser.add_argument( "--structure_version", type=str, - choices=['STRUCTURE'], + choices=SUPPORT_STRUCTURE_MODEL_VERSION, default='STRUCTURE', help='Model version, the current model support list is as follows:' ' 1. STRUCTURE Support en table structure model.') @@ -292,7 +294,7 @@ class PaddleOCR(predict_system.TextSystem): """ params = parse_args(mMain=False) params.__dict__.update(**kwargs) - assert params.ocr_version in ['PP-OCR', 'PP-OCRv2'] + assert params.ocr_version in SUPPORT_OCR_MODEL_VERSION, "ocr_version must in {}, but get {}".format(SUPPORT_OCR_MODEL_VERSION,params.ocr_version) params.use_gpu = check_gpu(params.use_gpu) if not params.show_log: @@ -395,7 +397,7 @@ class PPStructure(OCRSystem): def __init__(self, **kwargs): params = parse_args(mMain=False) params.__dict__.update(**kwargs) - assert params.structure_version in ['STRUCTURE'] + assert params.structure_version in SUPPORT_STRUCTURE_MODEL_VERSION, "ocr_version must in {}, but get {}".format(SUPPORT_STRUCTURE_MODEL_VERSION,params.structure_version) params.use_gpu = check_gpu(params.use_gpu) if not params.show_log: