diff --git a/MANIFEST.in b/MANIFEST.in index d674fabc5d714f7e31ed00b32be2d44d6dd10871..1fcf184dacee9dcaf3d5b2e62d12c7b156e068c7 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -5,5 +5,6 @@ recursive-include ppocr/utils *.txt utility.py logging.py network.py recursive-include ppocr/data *.py recursive-include ppocr/postprocess *.py recursive-include tools/infer *.py +recursive-include tools __init__.py recursive-include ppocr/utils/e2e_utils *.py recursive-include ppstructure *.py \ No newline at end of file diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index a74e18d318699685400cc48430c04db3fef70b60..1a91ea95afb4ff91d3fd68fe0df6afaac9304661 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -1,29 +1,28 @@ Global: use_gpu: true - epoch_num: 50 + epoch_num: 400 log_smooth_window: 20 print_batch_step: 5 save_model_dir: ./output/table_mv3/ - save_epoch_step: 5 + save_epoch_step: 3 # evaluation is run every 400 iterations after the 0th iteration eval_batch_step: [0, 400] cal_metric_during_train: True - pretrained_model: + pretrained_model: checkpoints: save_inference_dir: use_visualdl: False - infer_img: doc/imgs_words/ch/word_1.jpg + infer_img: doc/table/table.jpg # for data or label process character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_type: en max_text_length: 100 - max_elem_length: 500 + max_elem_length: 800 max_cell_num: 500 infer_mode: False process_total_num: 0 process_cut_num: 0 - Optimizer: name: Adam beta1: 0.9 @@ -41,13 +40,15 @@ Architecture: Backbone: name: MobileNetV3 scale: 1.0 - model_name: small - disable_se: True + model_name: large Head: name: TableAttentionHead hidden_size: 256 l2_decay: 0.00001 loc_type: 2 + max_text_length: 100 + max_elem_length: 800 + max_cell_num: 500 Loss: name: TableAttentionLoss diff --git a/deploy/pdserving/pipeline_rpc_client.py b/deploy/pdserving/pipeline_rpc_client.py index 
4dcb1ad5f533729e344809e99951b59fb2908537..3d2a90f443f76ba142bbae05e00ea76b083335ba 100644 --- a/deploy/pdserving/pipeline_rpc_client.py +++ b/deploy/pdserving/pipeline_rpc_client.py @@ -41,6 +41,6 @@ for img_file in os.listdir(test_img_dir): image_data = file.read() image = cv2_to_base64(image_data) -for i in range(1): - ret = client.predict(feed_dict={"image": image}, fetch=["res"]) - print(ret) + for i in range(1): + ret = client.predict(feed_dict={"image": image}, fetch=["res"]) + print(ret) diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md index 8db64664f6ff560450a5ee99d708313c931989fc..e69f7355390f5c16cb1b498389a6552b5d40ca0a 100644 --- a/doc/doc_ch/detection.md +++ b/doc/doc_ch/detection.md @@ -96,7 +96,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \ # 单机多卡训练,通过 --gpus 参数设置使用的GPU ID python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \ -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained - + # 多机多卡训练,通过 --ips 参数设置使用的机器IP地址,通过 --gpus 参数设置使用的GPU ID python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \ -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained @@ -104,14 +104,14 @@ python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1 上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。 有关配置文件的详细解释,请参考[链接](./config.md)。 - + 您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001 ```shell python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001 ``` - -**注意:** 采用多机多卡训练时,需要替换上面命令中的ips值为您机器的地址,机器之间需要能够相互ping通。查看机器ip地址的命令为`ifconfig`。 - + +**注意:** 采用多机多卡训练时,需要替换上面命令中的ips值为您机器的地址,机器之间需要能够相互ping通。另外,训练时需要在多个机器上分别启动命令。查看机器ip地址的命令为`ifconfig`。 + 如果您想进一步加快训练速度,可以使用[自动混合精度训练](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html), 以单机单卡为例,命令如下: ```shell python3 tools/train.py -c 
configs/det/det_mv3_db.yml \ diff --git a/doc/doc_ch/whl.md b/doc/doc_ch/whl.md index ba5bbae6255382d0c7fa5be319946d6242b1a544..3b88709a2328409a266d0d482baa072dd7aa3824 100644 --- a/doc/doc_ch/whl.md +++ b/doc/doc_ch/whl.md @@ -420,3 +420,5 @@ im_show.save('result.jpg') | cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE | | show_log | 是否打印det和rec等信息 | FALSE | | type | 执行ocr或者表格结构化, 值可选['ocr','structure'] | ocr | +| ocr_version | OCR模型版本,可选PP-OCRv2, PP-OCR。PP-OCRv2 目前仅支持中文的检测和识别模型,PP-OCR支持中文的检测,识别,多语种识别,方向分类器等模型 | PP-OCRv2 | +| structure_version | 表格结构化模型版本,可选 STRUCTURE。STRUCTURE支持表格结构化模型 | STRUCTURE | diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md index 948733e16cebea2ce819367a863948434ece5ae5..586ed30bb841122717e66966337b5c99b9cf3397 100644 --- a/doc/doc_en/detection_en.md +++ b/doc/doc_en/detection_en.md @@ -98,14 +98,14 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o \ # multi-GPU training # Set the GPU ID used by the '--gpus' parameter. python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained - + # multi-Node, multi-GPU training # Set the IPs of your nodes used by the '--ips' parameter. Set the GPU ID used by the '--gpus' parameter. python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \ -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained ``` -**Note:** For multi-Node multi-GPU training, you need to replace the `ips` value in the preceding command with the address of your machine, and the machines must be able to ping each other. The command for viewing the IP address of the machine is `ifconfig`. - +**Note:** For multi-Node multi-GPU training, you need to replace the `ips` value in the preceding command with the address of your machine, and the machines must be able to ping each other. 
In addition, the launch command must be run separately on each machine to start the training. The command for viewing the IP address of the machine is `ifconfig`. + If you want to further speed up the training, you can use [automatic mixed precision training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_en.html). for single card training, the command is as follows: ``` python3 tools/train.py -c configs/det/det_mv3_db.yml \ diff --git a/doc/doc_en/whl_en.md b/doc/doc_en/whl_en.md index c2577e1e151e4675abab5139da099db9ad20fb4b..62aa452dcd36906c6480031375e6ca94f8a36de3 100644 --- a/doc/doc_en/whl_en.md +++ b/doc/doc_en/whl_en.md @@ -366,4 +366,6 @@ im_show.save('result.jpg') | rec | Enable recognition when `ppocr.ocr` func exec | TRUE | | cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE | | show_log | Whether to print log in det and rec | FALSE | -| type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr | \ No newline at end of file +| type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr | +| ocr_version | OCR model version. Supported versions: PP-OCRv2 supports the Chinese detection and recognition models; PP-OCR supports the Chinese detection, recognition and direction classifier models and the multilingual recognition models | PP-OCRv2 | +| structure_version | Table structure model version. Supported versions: STRUCTURE supports the English table structure model | STRUCTURE | diff --git a/paddleocr.py b/paddleocr.py index a98efd34088701d5eb5602743cf75b7d5e80157f..028cfcc1faae3d9cca7d756b55213c030c7496de 100644 --- a/paddleocr.py +++ b/paddleocr.py @@ -16,6 +16,9 @@ import os import sys __dir__ = os.path.dirname(__file__) + +import paddle +
sys.path.append(os.path.join(__dir__, '')) import cv2 @@ -29,7 +32,7 @@ from ppocr.utils.logging import get_logger logger = get_logger() from ppocr.utils.utility import check_and_read_gif, get_image_file_list from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url -from tools.infer.utility import draw_ocr, str2bool +from tools.infer.utility import draw_ocr, str2bool, check_gpu from ppstructure.utility import init_args, draw_structure_result from ppstructure.predict_system import OCRSystem, save_structure_res @@ -39,130 +42,137 @@ __all__ = [ ] SUPPORT_DET_MODEL = ['DB'] -VERSION = '2.2.1' +VERSION = '2.3.0.1' SUPPORT_REC_MODEL = ['CRNN'] BASE_DIR = os.path.expanduser("~/.paddleocr/") -DEFAULT_MODEL_VERSION = '2.0' +DEFAULT_OCR_MODEL_VERSION = 'PP-OCR' +DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE' MODEL_URLS = { - '2.1': { - 'det': { - 'ch': { - 'url': - 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar', - }, - }, - 'rec': { - 'ch': { - 'url': - 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar', - 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' - } - } - }, - '2.0': { - 'det': { - 'ch': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', - }, - 'en': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar', + 'OCR': { + 'PP-OCRv2': { + 'det': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar', + }, }, - 'structure': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar' + 'rec': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar', + 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' + } } }, - 'rec': { - 'ch': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', - 'dict_path': 
'./ppocr/utils/ppocr_keys_v1.txt' - }, - 'en': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/en_dict.txt' - }, - 'french': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/french_dict.txt' - }, - 'german': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/german_dict.txt' - }, - 'korean': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/korean_dict.txt' - }, - 'japan': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/japan_dict.txt' - }, - 'chinese_cht': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' - }, - 'ta': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/ta_dict.txt' + DEFAULT_OCR_MODEL_VERSION: { + 'det': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar', + }, + 'en': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar', + }, + 'structure': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar' + } }, - 'te': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/te_dict.txt' + 'rec': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/ppocr_keys_v1.txt' + }, + 'en': 
{ + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/en_dict.txt' + }, + 'french': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/french_dict.txt' + }, + 'german': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/german_dict.txt' + }, + 'korean': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/korean_dict.txt' + }, + 'japan': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/japan_dict.txt' + }, + 'chinese_cht': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt' + }, + 'ta': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/ta_dict.txt' + }, + 'te': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/te_dict.txt' + }, + 'ka': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/ka_dict.txt' + }, + 'latin': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/latin_dict.txt' + }, + 'arabic': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/arabic_dict.txt' + }, + 'cyrillic': { + 'url': + 
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt' + }, + 'devanagari': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', + 'dict_path': './ppocr/utils/dict/devanagari_dict.txt' + }, + 'structure': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar', + 'dict_path': 'ppocr/utils/dict/table_dict.txt' + } }, - 'ka': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/ka_dict.txt' + 'cls': { + 'ch': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar', + } }, - 'latin': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/latin_dict.txt' - }, - 'arabic': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/arabic_dict.txt' - }, - 'cyrillic': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/cyrillic_dict.txt' - }, - 'devanagari': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar', - 'dict_path': './ppocr/utils/dict/devanagari_dict.txt' - }, - 'structure': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar', - 'dict_path': 'ppocr/utils/dict/table_dict.txt' - } - }, - 'cls': { - 'ch': { - 'url': - 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar', - } - }, - 'table': { - 'en': { - 'url': - 
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar', - 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt' + } + }, + 'STRUCTURE': { + DEFAULT_STRUCTURE_MODEL_VERSION: { + 'table': { + 'en': { + 'url': + 'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar', + 'dict_path': 'ppocr/utils/dict/table_structure_dict.txt' + } } } } @@ -177,7 +187,20 @@ def parse_args(mMain=True): parser.add_argument("--det", type=str2bool, default=True) parser.add_argument("--rec", type=str2bool, default=True) parser.add_argument("--type", type=str, default='ocr') - parser.add_argument("--version", type=str, default='2.1') + parser.add_argument( + "--ocr_version", + type=str, + default='PP-OCRv2', + help='OCR Model version, the current model support list is as follows: ' + '1. PP-OCRv2 Support Chinese detection and recognition model. ' + '2. PP-OCR support Chinese detection, recognition and direction classifier and multilingual recognition model.' + ) + parser.add_argument( + "--structure_version", + type=str, + default='STRUCTURE', + help='Model version, the current model support list is as follows:' + ' 1. 
def get_model_config(type, version, model_type, lang):
    """Look up one model entry in ``MODEL_URLS``, falling back to the default version.

    Args:
        type: Top-level key of ``MODEL_URLS``; must be ``'OCR'`` or
            ``'STRUCTURE'``. (The name shadows the builtin but is kept so
            existing keyword callers keep working.)
        version: Requested model version under ``MODEL_URLS[type]``.
        model_type: Model family key inside a version (the callers in this
            file pass ``'det'``, ``'rec'``, ``'cls'`` and ``'table'``).
        lang: Language key inside the model family.

    Returns:
        The dict stored at ``MODEL_URLS[type][version][model_type][lang]``
        after any fallback to the default version has been applied.

    Raises:
        NotImplementedError: If ``type`` is neither ``'OCR'`` nor
            ``'STRUCTURE'``.
        SystemExit: With status ``-1`` when neither the requested version nor
            the default version provides the model type / language.
    """
    if type == 'OCR':
        default_version = DEFAULT_OCR_MODEL_VERSION
    elif type == 'STRUCTURE':
        default_version = DEFAULT_STRUCTURE_MODEL_VERSION
    else:
        raise NotImplementedError(
            "type must be 'OCR' or 'STRUCTURE', but got {}".format(type))

    model_urls = MODEL_URLS[type]
    default_models = model_urls[default_version]

    if version not in model_urls:
        logger.warning('version {} not in {}, auto switch to version {}'.format(
            version, model_urls.keys(), default_version))
        version = default_version

    if model_type not in model_urls[version]:
        if model_type in default_models:
            logger.warning(
                'version {} not support {} models, auto switch to version {}'.
                format(version, model_type, default_version))
            version = default_version
        else:
            logger.error('{} models is not support, we only support {}'.format(
                model_type, default_models.keys()))
            sys.exit(-1)

    if lang not in model_urls[version][model_type]:
        # The requested version may offer a model type that the default
        # version lacks; use .get() so that case reaches the logged error
        # path below instead of raising KeyError.
        default_langs = default_models.get(model_type, {})
        if lang in default_langs:
            logger.warning(
                'lang {} is not support in {}, auto switch to version {}'.
                format(lang, version, default_version))
            version = default_version
        else:
            logger.error(
                'lang {} is not support, we only support {} for {} models'.
                format(lang, default_langs.keys(), model_type))
            sys.exit(-1)

    return model_urls[version][model_type][lang]
def __init__(self,
             in_channels,
             hidden_size,
             loc_type,
             in_max_len=488,
             max_text_length=100,
             max_elem_length=800,
             max_cell_num=500,
             **kwargs):
    """Build the structure-decoding and location-regression layers.

    Args:
        in_channels: Indexable collection; its last element is used as the
            input feature size.
        hidden_size: Hidden width shared by the attention cell and the
            linear generators.
        loc_type: ``1`` builds a location generator fed by the hidden state
            only; any other value also builds ``loc_fea_trans`` and feeds
            the generator with input features plus hidden state.
        in_max_len: Selects the input width of ``loc_fea_trans``
            (640 -> 400, 800 -> 625, anything else -> 256).
        max_text_length: Stored on the instance as given.
        max_elem_length: Stored on the instance; ``max_elem_length + 1`` is
            the output width of ``loc_fea_trans``.
        max_cell_num: Stored on the instance as given.
        **kwargs: Accepted and ignored.
    """
    super(TableAttentionHead, self).__init__()

    # Plain configuration values.
    self.input_size = in_channels[-1]
    self.hidden_size = hidden_size
    self.elem_num = 30
    self.max_text_length = max_text_length
    self.max_elem_length = max_elem_length
    self.max_cell_num = max_cell_num

    # Structure branch: attention cell followed by a linear classifier.
    self.structure_attention_cell = AttentionGRUCell(
        self.input_size, hidden_size, self.elem_num, use_gru=False)
    self.structure_generator = nn.Linear(hidden_size, self.elem_num)
    self.loc_type = loc_type
    self.in_max_len = in_max_len

    # Location branch.
    if self.loc_type == 1:
        self.loc_generator = nn.Linear(hidden_size, 4)
        return

    if self.in_max_len == 640:
        loc_in_features = 400
    elif self.in_max_len == 800:
        loc_in_features = 625
    else:
        loc_in_features = 256
    self.loc_fea_trans = nn.Linear(loc_in_features,
                                   self.max_elem_length + 1)
    self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def download_with_progressbar(url, save_path):
    """Stream ``url`` into ``save_path`` while showing a tqdm progress bar.

    Args:
        url: Address fetched with ``requests.get(url, stream=True)``.
        save_path: Destination path, opened in binary write mode.

    Raises:
        SystemExit: With a non-zero status when the server answers with any
            HTTP status other than 200 (an error is logged first).
    """
    logger = get_logger()
    # The context manager releases the streamed connection on every exit path.
    with requests.get(url, stream=True) as response:
        if response.status_code != 200:
            logger.error("Something went wrong while downloading models")
            # Exit non-zero: a failed download must not look like success
            # to the calling shell or script.
            sys.exit(1)

        total_size_in_bytes = int(response.headers.get('content-length', 1))
        block_size = 1024  # 1 Kibibyte
        # Both the bar and the file are closed even if a write fails midway.
        with tqdm(
                total=total_size_in_bytes, unit='iB',
                unit_scale=True) as progress_bar, open(save_path,
                                                       'wb') as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
-opset_version_value=$(func_parser_value "${lines[6]}") -enable_onnx_checker_key=$(func_parser_key "${lines[7]}") -enable_onnx_checker_value=$(func_parser_value "${lines[7]}") +model_name=$(func_parser_value "${lines[1]}") +python=$(func_parser_value "${lines[2]}") +padlle2onnx_cmd=$(func_parser_value "${lines[3]}") +infer_model_dir_key=$(func_parser_key "${lines[4]}") +infer_model_dir_value=$(func_parser_value "${lines[4]}") +model_filename_key=$(func_parser_key "${lines[5]}") +model_filename_value=$(func_parser_value "${lines[5]}") +params_filename_key=$(func_parser_key "${lines[6]}") +params_filename_value=$(func_parser_value "${lines[6]}") +save_file_key=$(func_parser_key "${lines[7]}") +save_file_value=$(func_parser_value "${lines[7]}") +opset_version_key=$(func_parser_key "${lines[8]}") +opset_version_value=$(func_parser_value "${lines[8]}") +enable_onnx_checker_key=$(func_parser_key "${lines[9]}") +enable_onnx_checker_value=$(func_parser_value "${lines[9]}") # parser onnx inference -inference_py=$(func_parser_value "${lines[8]}") -use_gpu_key=$(func_parser_key "${lines[9]}") -use_gpu_value=$(func_parser_value "${lines[9]}") -det_model_key=$(func_parser_key "${lines[10]}") -image_dir_key=$(func_parser_key "${lines[11]}") -image_dir_value=$(func_parser_value "${lines[11]}") +inference_py=$(func_parser_value "${lines[10]}") +use_gpu_key=$(func_parser_key "${lines[11]}") +use_gpu_value=$(func_parser_value "${lines[11]}") +det_model_key=$(func_parser_key "${lines[12]}") +image_dir_key=$(func_parser_key "${lines[13]}") +image_dir_value=$(func_parser_value "${lines[13]}") LOG_PATH="./test_tipc/output" diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d56c9dbaa1304b160521da03c05db2352e341bf2 --- /dev/null +++ b/tools/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. 
def check_gpu(use_gpu):
    """Return the GPU setting that can actually be honoured.

    Args:
        use_gpu: The caller's requested setting; any truthy value asks for GPU.

    Returns:
        ``use_gpu`` unchanged when it is falsy or when
        ``paddle.is_compiled_with_cuda()`` is true; ``False`` otherwise.
    """
    if use_gpu and not paddle.is_compiled_with_cuda():
        # Function-scope import so the block does not depend on a new
        # module-level import.
        import warnings

        # Tell the user about the downgrade instead of silently switching
        # to CPU.
        warnings.warn(
            'use_gpu is set, but the installed paddle is not compiled with '
            'CUDA; falling back to CPU.', RuntimeWarning)
        use_gpu = False
    return use_gpu