提交 254ed0ba 编写于 作者: L LDOUBLEV

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into test_v10

......@@ -5,5 +5,6 @@ recursive-include ppocr/utils *.txt utility.py logging.py network.py
recursive-include ppocr/data *.py
recursive-include ppocr/postprocess *.py
recursive-include tools/infer *.py
recursive-include tools __init__.py
recursive-include ppocr/utils/e2e_utils *.py
recursive-include ppstructure *.py
\ No newline at end of file
Global:
use_gpu: true
epoch_num: 50
epoch_num: 400
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 5
save_epoch_step: 3
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
cal_metric_during_train: True
pretrained_model:
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/imgs_words/ch/word_1.jpg
infer_img: doc/table/table.jpg
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 500
max_elem_length: 800
max_cell_num: 500
infer_mode: False
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
......@@ -41,13 +40,15 @@ Architecture:
Backbone:
name: MobileNetV3
scale: 1.0
model_name: small
disable_se: True
model_name: large
Head:
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
Loss:
name: TableAttentionLoss
......
......@@ -41,6 +41,6 @@ for img_file in os.listdir(test_img_dir):
image_data = file.read()
image = cv2_to_base64(image_data)
for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["res"])
print(ret)
for i in range(1):
ret = client.predict(feed_dict={"image": image}, fetch=["res"])
print(ret)
......@@ -96,7 +96,7 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml \
# 单机多卡训练,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
# 多机多卡训练,通过 --ips 参数设置使用的机器IP地址,通过 --gpus 参数设置使用的GPU ID
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
......@@ -104,14 +104,14 @@ python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1
上述指令中,通过-c 选择训练使用configs/det/det_db_mv3.yml配置文件。
有关配置文件的详细解释,请参考[链接](./config.md)
您也可以通过-o参数在不需要修改yml文件的情况下,改变训练的参数,比如,调整训练的学习率为0.0001
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
```
**注意:** 采用多机多卡训练时,需要替换上面命令中的ips值为您机器的地址,机器之间需要能够相互ping通。查看机器ip地址的命令为`ifconfig`
**注意:** 采用多机多卡训练时,需要替换上面命令中的ips值为您机器的地址,机器之间需要能够相互ping通。另外,训练时需要在多个机器上分别启动命令。查看机器ip地址的命令为`ifconfig`
如果您想进一步加快训练速度,可以使用[自动混合精度训练](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_cn.html), 以单机单卡为例,命令如下:
```shell
python3 tools/train.py -c configs/det/det_mv3_db.yml \
......
......@@ -420,3 +420,5 @@ im_show.save('result.jpg')
| cls | 前向时是否启动分类 (命令行模式下使用use_angle_cls控制前向是否启动分类) | FALSE |
| show_log | 是否打印det和rec等信息 | FALSE |
| type | 执行ocr或者表格结构化, 值可选['ocr','structure'] | ocr |
| ocr_version | OCR模型版本,可选PP-OCRv2, PP-OCR。PP-OCRv2 目前仅支持中文的检测和识别模型,PP-OCR支持中文的检测,识别,多语种识别,方向分类器等模型 | PP-OCRv2 |
| structure_version | 表格结构化模型版本,可选 STRUCTURE。STRUCTURE支持表格结构化模型 | STRUCTURE |
......@@ -98,14 +98,14 @@ python3 tools/train.py -c configs/det/det_mv3_db.yml -o \
# multi-GPU training
# Set the GPU ID used by the '--gpus' parameter.
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
# multi-Node, multi-GPU training
# Set the IPs of your nodes used by the '--ips' parameter. Set the GPU ID used by the '--gpus' parameter.
python3 -m paddle.distributed.launch --ips="xx.xx.xx.xx,xx.xx.xx.xx" --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrained_model=./pretrain_models/MobileNetV3_large_x0_5_pretrained
```
**Note:** For multi-Node multi-GPU training, you need to replace the `ips` value in the preceding command with the address of your machine, and the machines must be able to ping each other. The command for viewing the IP address of the machine is `ifconfig`.
**Note:** For multi-Node multi-GPU training, you need to replace the `ips` value in the preceding command with the address of your machine, and the machines must be able to ping each other. In addition, it requires activating commands separately on multiple machines when we start the training. The command for viewing the IP address of the machine is `ifconfig`.
If you want to further speed up the training, you can use [automatic mixed precision training](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/01_paddle2.0_introduction/basic_concept/amp_en.html). for single card training, the command is as follows:
```
python3 tools/train.py -c configs/det/det_mv3_db.yml \
......
......@@ -366,4 +366,6 @@ im_show.save('result.jpg')
| rec | Enable recognition when `ppocr.ocr` func exec | TRUE |
| cls | Enable classification when `ppocr.ocr` func exec((Use use_angle_cls in command line mode to control whether to start classification in the forward direction) | FALSE |
| show_log | Whether to print log in det and rec | FALSE |
| type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr |
\ No newline at end of file
| type | Perform ocr or table structuring, the value is selected in ['ocr','structure'] | ocr |
| ocr_version | OCR Model version number, the current model support list is as follows: PP-OCRv2 support Chinese detection and recognition model, PP-OCR support Chinese detection, recognition and direction classifier, multilingual recognition model | PP-OCRv2 |
| structure_version | table structure Model version number, the current model support list is as follows: STRUCTURE support english table structure model | STRUCTURE |
......@@ -16,6 +16,9 @@ import os
import sys
__dir__ = os.path.dirname(__file__)
import paddle
sys.path.append(os.path.join(__dir__, ''))
import cv2
......@@ -29,7 +32,7 @@ from ppocr.utils.logging import get_logger
logger = get_logger()
from ppocr.utils.utility import check_and_read_gif, get_image_file_list
from ppocr.utils.network import maybe_download, download_with_progressbar, is_link, confirm_model_dir_url
from tools.infer.utility import draw_ocr, str2bool
from tools.infer.utility import draw_ocr, str2bool, check_gpu
from ppstructure.utility import init_args, draw_structure_result
from ppstructure.predict_system import OCRSystem, save_structure_res
......@@ -39,130 +42,137 @@ __all__ = [
]
SUPPORT_DET_MODEL = ['DB']
VERSION = '2.2.1'
VERSION = '2.3.0.1'
SUPPORT_REC_MODEL = ['CRNN']
BASE_DIR = os.path.expanduser("~/.paddleocr/")
DEFAULT_MODEL_VERSION = '2.0'
DEFAULT_OCR_MODEL_VERSION = 'PP-OCR'
DEFAULT_STRUCTURE_MODEL_VERSION = 'STRUCTURE'
MODEL_URLS = {
'2.1': {
'det': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar',
},
},
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
}
}
},
'2.0': {
'det': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
'OCR': {
'PP-OCRv2': {
'det': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_det_infer.tar',
},
},
'structure': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/PP-OCRv2/chinese/ch_PP-OCRv2_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
}
}
},
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
DEFAULT_OCR_MODEL_VERSION: {
'det': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_infer.tar',
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_ppocr_mobile_v2.0_det_infer.tar',
},
'structure': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar'
}
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
'rec': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/ppocr_keys_v1.txt'
},
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/en_number_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/en_dict.txt'
},
'french': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/french_dict.txt'
},
'german': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/german_dict.txt'
},
'korean': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/korean_dict.txt'
},
'japan': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/japan_dict.txt'
},
'chinese_cht': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/chinese_cht_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/chinese_cht_dict.txt'
},
'ta': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ta_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ta_dict.txt'
},
'te': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/te_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/te_dict.txt'
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
},
'structure': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'dict_path': 'ppocr/utils/dict/table_dict.txt'
}
},
'ka': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/ka_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/ka_dict.txt'
'cls': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
}
},
'latin': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/latin_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/latin_dict.txt'
},
'arabic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/arabic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/arabic_dict.txt'
},
'cyrillic': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/cyrillic_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/cyrillic_dict.txt'
},
'devanagari': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/devanagari_ppocr_mobile_v2.0_rec_infer.tar',
'dict_path': './ppocr/utils/dict/devanagari_dict.txt'
},
'structure': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar',
'dict_path': 'ppocr/utils/dict/table_dict.txt'
}
},
'cls': {
'ch': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_cls_infer.tar',
}
},
'table': {
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
}
},
'STRUCTURE': {
DEFAULT_STRUCTURE_MODEL_VERSION: {
'table': {
'en': {
'url':
'https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar',
'dict_path': 'ppocr/utils/dict/table_structure_dict.txt'
}
}
}
}
......@@ -177,7 +187,20 @@ def parse_args(mMain=True):
parser.add_argument("--det", type=str2bool, default=True)
parser.add_argument("--rec", type=str2bool, default=True)
parser.add_argument("--type", type=str, default='ocr')
parser.add_argument("--version", type=str, default='2.1')
parser.add_argument(
"--ocr_version",
type=str,
default='PP-OCRv2',
help='OCR Model version, the current model support list is as follows: '
'1. PP-OCRv2 Support Chinese detection and recognition model. '
'2. PP-OCR support Chinese detection, recognition and direction classifier and multilingual recognition model.'
)
parser.add_argument(
"--structure_version",
type=str,
default='STRUCTURE',
help='Model version, the current model support list is as follows:'
' 1. STRUCTURE Support en table structure model.')
for action in parser._actions:
if action.dest in ['rec_char_dict_path', 'table_char_dict_path']:
......@@ -215,9 +238,9 @@ def parse_lang(lang):
lang = "cyrillic"
elif lang in devanagari_lang:
lang = "devanagari"
assert lang in MODEL_URLS[DEFAULT_MODEL_VERSION][
assert lang in MODEL_URLS['OCR'][DEFAULT_OCR_MODEL_VERSION][
'rec'], 'param lang must in {}, but got {}'.format(
MODEL_URLS[DEFAULT_MODEL_VERSION]['rec'].keys(), lang)
MODEL_URLS['OCR'][DEFAULT_OCR_MODEL_VERSION]['rec'].keys(), lang)
if lang == "ch":
det_lang = "ch"
elif lang == 'structure':
......@@ -227,33 +250,41 @@ def parse_lang(lang):
return lang, det_lang
def get_model_config(version, model_type, lang):
if version not in MODEL_URLS:
logger.warning('version {} not in {}, use version {} instead'.format(
version, MODEL_URLS.keys(), DEFAULT_MODEL_VERSION))
def get_model_config(type, version, model_type, lang):
if type == 'OCR':
DEFAULT_MODEL_VERSION = DEFAULT_OCR_MODEL_VERSION
elif type == 'STRUCTURE':
DEFAULT_MODEL_VERSION = DEFAULT_STRUCTURE_MODEL_VERSION
else:
raise NotImplementedError
model_urls = MODEL_URLS[type]
if version not in model_urls:
logger.warning('version {} not in {}, auto switch to version {}'.format(
version, model_urls.keys(), DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
if model_type not in MODEL_URLS[version]:
if model_type in MODEL_URLS[DEFAULT_MODEL_VERSION]:
if model_type not in model_urls[version]:
if model_type in model_urls[DEFAULT_MODEL_VERSION]:
logger.warning(
'version {} not support {} models, use version {} instead'.
'version {} not support {} models, auto switch to version {}'.
format(version, model_type, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error('{} models is not support, we only support {}'.format(
model_type, MODEL_URLS[DEFAULT_MODEL_VERSION].keys()))
model_type, model_urls[DEFAULT_MODEL_VERSION].keys()))
sys.exit(-1)
if lang not in MODEL_URLS[version][model_type]:
if lang in MODEL_URLS[DEFAULT_MODEL_VERSION][model_type]:
logger.warning('lang {} is not support in {}, use {} instead'.
format(lang, version, DEFAULT_MODEL_VERSION))
if lang not in model_urls[version][model_type]:
if lang in model_urls[DEFAULT_MODEL_VERSION][model_type]:
logger.warning(
'lang {} is not support in {}, auto switch to version {}'.
format(lang, version, DEFAULT_MODEL_VERSION))
version = DEFAULT_MODEL_VERSION
else:
logger.error(
'lang {} is not support, we only support {} for {} models'.
format(lang, MODEL_URLS[DEFAULT_MODEL_VERSION][model_type].keys(
format(lang, model_urls[DEFAULT_MODEL_VERSION][model_type].keys(
), model_type))
sys.exit(-1)
return MODEL_URLS[version][model_type][lang]
return model_urls[version][model_type][lang]
class PaddleOCR(predict_system.TextSystem):
......@@ -265,23 +296,28 @@ class PaddleOCR(predict_system.TextSystem):
"""
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
logger.setLevel(logging.INFO)
self.use_angle_cls = params.use_angle_cls
lang, det_lang = parse_lang(params.lang)
# init model dir
det_model_config = get_model_config(params.version, 'det', det_lang)
det_model_config = get_model_config('OCR', params.ocr_version, 'det',
det_lang)
params.det_model_dir, det_url = confirm_model_dir_url(
params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
det_model_config['url'])
rec_model_config = get_model_config(params.version, 'rec', lang)
rec_model_config = get_model_config('OCR', params.ocr_version, 'rec',
lang)
params.rec_model_dir, rec_url = confirm_model_dir_url(
params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
rec_model_config['url'])
cls_model_config = get_model_config(params.version, 'cls', 'ch')
cls_model_config = get_model_config('OCR', params.ocr_version, 'cls',
'ch')
params.cls_model_dir, cls_url = confirm_model_dir_url(
params.cls_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'cls'),
......@@ -362,22 +398,27 @@ class PPStructure(OCRSystem):
def __init__(self, **kwargs):
params = parse_args(mMain=False)
params.__dict__.update(**kwargs)
params.use_gpu = check_gpu(params.use_gpu)
if not params.show_log:
logger.setLevel(logging.INFO)
lang, det_lang = parse_lang(params.lang)
# init model dir
det_model_config = get_model_config(params.version, 'det', det_lang)
det_model_config = get_model_config('OCR', params.ocr_version, 'det',
det_lang)
params.det_model_dir, det_url = confirm_model_dir_url(
params.det_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'det', det_lang),
det_model_config['url'])
rec_model_config = get_model_config(params.version, 'rec', lang)
rec_model_config = get_model_config('OCR', params.ocr_version, 'rec',
lang)
params.rec_model_dir, rec_url = confirm_model_dir_url(
params.rec_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'rec', lang),
rec_model_config['url'])
table_model_config = get_model_config(params.version, 'table', 'en')
table_model_config = get_model_config(
'STRUCTURE', params.structure_version, 'table', 'en')
params.table_model_dir, table_url = confirm_model_dir_url(
params.table_model_dir,
os.path.join(BASE_DIR, VERSION, 'ocr', 'table'),
......
......@@ -23,32 +23,40 @@ import numpy as np
class TableAttentionHead(nn.Layer):
def __init__(self, in_channels, hidden_size, loc_type, in_max_len=488, **kwargs):
def __init__(self,
in_channels,
hidden_size,
loc_type,
in_max_len=488,
max_text_length=100,
max_elem_length=800,
max_cell_num=500,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.elem_num = 30
self.max_text_length = 100
self.max_elem_length = 500
self.max_cell_num = 500
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.loc_type = loc_type
self.in_max_len = in_max_len
if self.loc_type == 1:
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length+1)
self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length+1)
self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length+1)
self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
......@@ -60,16 +68,16 @@ class TableAttentionHead(nn.Layer):
if len(fea.shape) == 3:
pass
else:
last_shape = int(np.prod(fea.shape[2:])) # gry added
last_shape = int(np.prod(fea.shape[2:])) # gry added
fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
batch_size = fea.shape[0]
hidden = paddle.zeros((batch_size, self.hidden_size))
output_hiddens = []
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_elem_length+1):
for i in range(self.max_elem_length + 1):
elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
......@@ -96,7 +104,7 @@ class TableAttentionHead(nn.Layer):
alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length)
i = 0
while i < max_elem_length+1:
while i < max_elem_length + 1:
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num)
(outputs, hidden), alpha = self.structure_attention_cell(
......@@ -105,7 +113,7 @@ class TableAttentionHead(nn.Layer):
structure_probs_step = self.structure_generator(outputs)
temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")
i += 1
output = paddle.concat(output_hiddens, axis=1)
structure_probs = self.structure_generator(output)
structure_probs = F.softmax(structure_probs)
......@@ -119,9 +127,9 @@ class TableAttentionHead(nn.Layer):
loc_concat = paddle.concat([output, loc_fea], axis=2)
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs':structure_probs, 'loc_preds':loc_preds}
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
......
......@@ -24,15 +24,17 @@ from ppocr.utils.logging import get_logger
def download_with_progressbar(url, save_path):
logger = get_logger()
response = requests.get(url, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
if total_size_in_bytes == 0 or progress_bar.n != total_size_in_bytes:
if response.status_code == 200:
total_size_in_bytes = int(response.headers.get('content-length', 1))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm(
total=total_size_in_bytes, unit='iB', unit_scale=True)
with open(save_path, 'wb') as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
file.write(data)
progress_bar.close()
else:
logger.error("Something went wrong while downloading models")
sys.exit(0)
......@@ -45,7 +47,7 @@ def maybe_download(model_storage_directory, url):
if not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdiparams')
) or not os.path.exists(
os.path.join(model_storage_directory, 'inference.pdmodel')):
os.path.join(model_storage_directory, 'inference.pdmodel')):
assert url.endswith('.tar'), 'Only supports tar compressed package'
tmp_path = os.path.join(model_storage_directory, url.split('/')[-1])
print('download {} to {}'.format(url, tmp_path))
......
......@@ -16,26 +16,28 @@ IFS=$'\n'
lines=(${dataline})
# parser paddle2onnx
padlle2onnx_cmd=$(func_parser_value "${lines[1]}")
infer_model_dir_key=$(func_parser_key "${lines[2]}")
infer_model_dir_value=$(func_parser_value "${lines[2]}")
model_filename_key=$(func_parser_key "${lines[3]}")
model_filename_value=$(func_parser_value "${lines[3]}")
params_filename_key=$(func_parser_key "${lines[4]}")
params_filename_value=$(func_parser_value "${lines[4]}")
save_file_key=$(func_parser_key "${lines[5]}")
save_file_value=$(func_parser_value "${lines[5]}")
opset_version_key=$(func_parser_key "${lines[6]}")
opset_version_value=$(func_parser_value "${lines[6]}")
enable_onnx_checker_key=$(func_parser_key "${lines[7]}")
enable_onnx_checker_value=$(func_parser_value "${lines[7]}")
model_name=$(func_parser_value "${lines[1]}")
python=$(func_parser_value "${lines[2]}")
padlle2onnx_cmd=$(func_parser_value "${lines[3]}")
infer_model_dir_key=$(func_parser_key "${lines[4]}")
infer_model_dir_value=$(func_parser_value "${lines[4]}")
model_filename_key=$(func_parser_key "${lines[5]}")
model_filename_value=$(func_parser_value "${lines[5]}")
params_filename_key=$(func_parser_key "${lines[6]}")
params_filename_value=$(func_parser_value "${lines[6]}")
save_file_key=$(func_parser_key "${lines[7]}")
save_file_value=$(func_parser_value "${lines[7]}")
opset_version_key=$(func_parser_key "${lines[8]}")
opset_version_value=$(func_parser_value "${lines[8]}")
enable_onnx_checker_key=$(func_parser_key "${lines[9]}")
enable_onnx_checker_value=$(func_parser_value "${lines[9]}")
# parser onnx inference
inference_py=$(func_parser_value "${lines[8]}")
use_gpu_key=$(func_parser_key "${lines[9]}")
use_gpu_value=$(func_parser_value "${lines[9]}")
det_model_key=$(func_parser_key "${lines[10]}")
image_dir_key=$(func_parser_key "${lines[11]}")
image_dir_value=$(func_parser_value "${lines[11]}")
inference_py=$(func_parser_value "${lines[10]}")
use_gpu_key=$(func_parser_key "${lines[11]}")
use_gpu_value=$(func_parser_value "${lines[11]}")
det_model_key=$(func_parser_key "${lines[12]}")
image_dir_key=$(func_parser_key "${lines[13]}")
image_dir_value=$(func_parser_value "${lines[13]}")
LOG_PATH="./test_tipc/output"
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
......@@ -17,7 +17,7 @@ import os
import sys
import cv2
import numpy as np
import json
import paddle
from PIL import Image, ImageDraw, ImageFont
import math
from paddle import inference
......@@ -601,5 +601,12 @@ def get_rotate_crop_image(img, points):
return dst_img
def check_gpu(use_gpu):
if use_gpu and not paddle.is_compiled_with_cuda():
use_gpu = False
return use_gpu
if __name__ == '__main__':
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册