diff --git a/configs/table/table_mv3.yml b/configs/table/table_mv3.yml index 9355a236e15b60db18e8715c2702701fd5d36c71..9d286f4153eaab44bf0d259bbad4a0b3b8ada568 100755 --- a/configs/table/table_mv3.yml +++ b/configs/table/table_mv3.yml @@ -43,7 +43,6 @@ Architecture: Head: name: TableAttentionHead hidden_size: 256 - loc_type: 2 max_text_length: *max_text_length loc_reg_num: &loc_reg_num 4 diff --git a/ppocr/modeling/heads/table_att_head.py b/ppocr/modeling/heads/table_att_head.py index d3c86e22b02e08c18d8d5cb193f2ffb8b07ad785..50910c5b73aa2a41f329d7222fc8c632509b4c91 100644 --- a/ppocr/modeling/heads/table_att_head.py +++ b/ppocr/modeling/heads/table_att_head.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import math import paddle import paddle.nn as nn from paddle import ParamAttr @@ -42,7 +43,6 @@ class TableAttentionHead(nn.Layer): def __init__(self, in_channels, hidden_size, - loc_type, in_max_len=488, max_text_length=800, out_channels=30, @@ -57,20 +57,16 @@ class TableAttentionHead(nn.Layer): self.structure_attention_cell = AttentionGRUCell( self.input_size, hidden_size, self.out_channels, use_gru=False) self.structure_generator = nn.Linear(hidden_size, self.out_channels) - self.loc_type = loc_type self.in_max_len = in_max_len - if self.loc_type == 1: - self.loc_generator = nn.Linear(hidden_size, 4) + if self.in_max_len == 640: + self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) + elif self.in_max_len == 800: + self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1) else: - if self.in_max_len == 640: - self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1) - elif self.in_max_len == 800: - self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1) - else: - self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1) - self.loc_generator = nn.Linear(self.input_size + hidden_size, - loc_reg_num) + self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1) + self.loc_generator = nn.Linear(self.input_size + hidden_size, + loc_reg_num) def _char_to_onehot(self, input_char, onehot_dim): input_ont_hot = F.one_hot(input_char, onehot_dim) @@ -80,16 +76,13 @@ class TableAttentionHead(nn.Layer): # if and else branch are both needed when you want to assign a variable # if you modify the var in just one branch, then the modification will not work. fea = inputs[-1] - if len(fea.shape) == 3: - pass - else: - last_shape = int(np.prod(fea.shape[2:])) # gry added - fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) - fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) + last_shape = int(np.prod(fea.shape[2:])) # gry added + fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape]) + fea = fea.transpose([0, 2, 1]) # (NTC)(batch, width, channels) batch_size = fea.shape[0] hidden = paddle.zeros((batch_size, self.hidden_size)) - output_hiddens = [] + output_hiddens = paddle.zeros((batch_size, self.max_text_length + 1, self.hidden_size)) if self.training and targets is not None: structure = targets[0] for i in range(self.max_text_length + 1): @@ -97,7 +90,8 @@ class TableAttentionHead(nn.Layer): structure[:, i], onehot_dim=self.out_channels) (outputs, hidden), alpha = self.structure_attention_cell( hidden, fea, elem_onehots) - output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output_hiddens[:, i, :] = outputs + # output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output = paddle.concat(output_hiddens, axis=1) structure_probs = self.structure_generator(output) if self.loc_type == 1: @@ -118,30 +112,25 @@ class TableAttentionHead(nn.Layer): outputs = None alpha = None max_text_length = paddle.to_tensor(self.max_text_length) - i = 0 - while i < max_text_length + 1: + for i in range(max_text_length + 1): elem_onehots = self._char_to_onehot( temp_elem, onehot_dim=self.out_channels) (outputs, hidden), alpha = self.structure_attention_cell( hidden, fea, elem_onehots) - output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) + output_hiddens[:, i, :] = outputs + # output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) structure_probs_step = self.structure_generator(outputs) temp_elem = structure_probs_step.argmax(axis=1, dtype="int32") - i += 1 - output = paddle.concat(output_hiddens, axis=1) + output = output_hiddens structure_probs = self.structure_generator(output) structure_probs = F.softmax(structure_probs) - if self.loc_type == 1: - loc_preds = self.loc_generator(output) - loc_preds = F.sigmoid(loc_preds) - else: - loc_fea = fea.transpose([0, 2, 1]) - loc_fea = self.loc_fea_trans(loc_fea) - loc_fea = loc_fea.transpose([0, 2, 1]) - loc_concat = paddle.concat([output, loc_fea], axis=2) - loc_preds = self.loc_generator(loc_concat) - loc_preds = F.sigmoid(loc_preds) + loc_fea = fea.transpose([0, 2, 1]) + loc_fea = self.loc_fea_trans(loc_fea) + loc_fea = loc_fea.transpose([0, 2, 1]) + loc_concat = paddle.concat([output, loc_fea], axis=2) + loc_preds = self.loc_generator(loc_concat) + loc_preds = F.sigmoid(loc_preds) return {'structure_probs': structure_probs, 'loc_preds': loc_preds} diff --git a/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..068c4c6b1d2655b9dcda1120425de7d52d0d543d --- /dev/null +++ b/test_tipc/configs/en_table_structure/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -0,0 +1,17 @@ +===========================paddle2onnx_params=========================== +model_name:en_table_structure +python:python3.7 +2onnx: paddle2onnx +--det_model_dir:./inference/en_ppocr_mobile_v2.0_table_structure_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--det_save_file:./inference/en_ppocr_mobile_v2.0_table_structure_infer/model.onnx +--rec_model_dir: +--rec_save_file: +--opset_version:10 +--enable_onnx_checker:True +inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt +--use_gpu:True|False +--det_model_dir: +--rec_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg \ No newline at end of file diff --git a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt index 617b726aa30ae52e9355963999c69d76d069dd65..1e9e9ce6182d060cc6b5ba14cdb4e54528224af4 100644 --- a/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt +++ b/test_tipc/configs/layoutxlm_ser/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt @@ -13,7 +13,7 @@ train_infer_img_dir:ppstructure/docs/kie/input/zh_val_42.jpg null:null ## trainer:norm_train -norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml +norm_train:tools/train.py -c test_tipc/configs/layoutxlm_ser/ser_layoutxlm_xfund_zh.yml -o pact_train:null fpgm_train:null distill_train:null diff --git a/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..45e4e9e858914dd8596cef10625df8160afe45fb --- /dev/null +++ b/test_tipc/configs/slanet/model_linux_gpu_normal_normal_paddle2onnx_python_linux_cpu.txt @@ -0,0 +1,17 @@ +===========================paddle2onnx_params=========================== +model_name:slanet +python:python3.7 +2onnx: paddle2onnx +--det_model_dir:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/ +--model_filename:inference.pdmodel +--params_filename:inference.pdiparams +--det_save_file:./inference/ch_ppstructure_mobile_v2.0_SLANet_infer/model.onnx +--rec_model_dir: +--rec_save_file: +--opset_version:10 +--enable_onnx_checker:True +inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_structure_dict_ch.txt +--use_gpu:True|False +--det_model_dir: +--rec_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg \ No newline at end of file diff --git a/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt new file mode 100644 index 0000000000000000000000000000000000000000..cf962d53e95accbe9beed41703ffd8a14eb28876 --- /dev/null +++ b/test_tipc/configs/slanet/train_linux_gpu_fleet_normal_infer_python_linux_gpu_cpu.txt @@ -0,0 +1,53 @@ +===========================train_params=========================== +model_name:slanet +python:python3.7 +gpu_list:192.168.0.1,192.168.0.2;0,1 +Global.use_gpu:True +Global.auto_cast:fp32 +Global.epoch_num:lite_train_lite_infer=3|whole_train_whole_infer=50 +Global.save_model_dir:./output/ +Train.loader.batch_size_per_card:lite_train_lite_infer=16|whole_train_whole_infer=128 +Global.pretrained_model:./pretrain_models/en_ppstructure_mobile_v2.0_SLANet_train/best_accuracy +train_model_name:latest +train_infer_img_dir:./ppstructure/docs/table/table.jpg +null:null +## +trainer:norm_train +norm_train:tools/train.py -c test_tipc/configs/slanet/SLANet.yml -o +pact_train:null +fpgm_train:null +distill_train:null +null:null +null:null +## +===========================eval_params=========================== +eval:null +null:null +## +===========================infer_params=========================== +Global.save_inference_dir:./output/ +Global.checkpoints: +norm_export:tools/export_model.py -c test_tipc/configs/slanet/SLANet.yml -o +quant_export: +fpgm_export: +distill_export:null +export1:null +export2:null +## +infer_model:./inference/en_ppstructure_mobile_v2.0_SLANet_train +infer_export:null +infer_quant:False +inference:ppstructure/table/predict_table.py --det_model_dir=./inference/en_ppocr_mobile_v2.0_table_det_infer --rec_model_dir=./inference/en_ppocr_mobile_v2.0_table_rec_infer --rec_char_dict_path=./ppocr/utils/dict/table_dict.txt --table_char_dict_path=./ppocr/utils/dict/table_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --det_limit_side_len=736 --det_limit_type=min --output ./output/table +--use_gpu:True|False +--enable_mkldnn:False +--cpu_threads:6 +--rec_batch_num:1 +--use_tensorrt:False +--precision:fp32 +--table_model_dir: +--image_dir:./ppstructure/docs/table/table.jpg +null:null +--benchmark:False +null:null +===========================infer_benchmark_params========================== +random_infer_input:[{float32,[3,488,488]}] diff --git a/test_tipc/prepare.sh b/test_tipc/prepare.sh index ecb1e36bb1bb83c6ee2dcf1cb243e6ee60de5dd8..688deac0f379b50865fe6739529f9301ebcd919b 100644 --- a/test_tipc/prepare.sh +++ b/test_tipc/prepare.sh @@ -700,10 +700,18 @@ if [ ${MODE} = "cpp_infer" ];then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && tar xf ch_det_data_50.tar && cd ../ elif [[ ${model_name} =~ "en_table_structure" ]];then - wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_det_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate - cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../ + + cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar + if [ ${model_name} == "en_table_structure" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar + elif [ ${model_name} == "en_table_structure_PACT" ]; then + wget -nc https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_slim_infer.tar --no-check-certificate + tar xf en_ppocr_mobile_v2.0_table_structure_slim_infer.tar + fi + cd ../ elif [[ ${model_name} =~ "slanet" ]];then wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate @@ -791,6 +799,12 @@ if [ ${MODE} = "paddle2onnx_infer" ];then wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_det_infer.tar --no-check-certificate wget -nc -P ./inference https://paddleocr.bj.bcebos.com/PP-OCRv3/chinese/ch_PP-OCRv3_rec_infer.tar --no-check-certificate cd ./inference && tar xf ch_PP-OCRv3_det_infer.tar && tar xf ch_PP-OCRv3_rec_infer.tar && cd ../ + elif [[ ${model_name} =~ "slanet" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/ppstructure/models/slanet/ch_ppstructure_mobile_v2.0_SLANet_infer.tar --no-check-certificate + cd ./inference/ && tar xf ch_ppstructure_mobile_v2.0_SLANet_infer.tar && cd ../ + elif [[ ${model_name} =~ "en_table_structure" ]];then + wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar --no-check-certificate + cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_structure_infer.tar && cd ../ fi # wget data diff --git a/test_tipc/test_paddle2onnx.sh b/test_tipc/test_paddle2onnx.sh index 04bfb590f7c6e64cf136d3feef8594994cb86877..f035e6bb645a1e7927844232c2bff72f0480e38e 100644 --- a/test_tipc/test_paddle2onnx.sh +++ b/test_tipc/test_paddle2onnx.sh @@ -105,6 +105,19 @@ function func_paddle2onnx(){ eval $trans_model_cmd last_status=${PIPESTATUS[0]} status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_rec_log}" + elif [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then + # trans det + set_dirname=$(func_set_params "--model_dir" "${det_infer_model_dir_value}") + set_model_filename=$(func_set_params "${model_filename_key}" "${model_filename_value}") + set_params_filename=$(func_set_params "${params_filename_key}" "${params_filename_value}") + set_save_model=$(func_set_params "--save_file" "${det_save_file_value}") + set_opset_version=$(func_set_params "${opset_version_key}" "${opset_version_value}") + set_enable_onnx_checker=$(func_set_params "${enable_onnx_checker_key}" "${enable_onnx_checker_value}") + trans_det_log="${LOG_PATH}/trans_model_det.log" + trans_model_cmd="${padlle2onnx_cmd} ${set_dirname} ${set_model_filename} ${set_params_filename} ${set_save_model} ${set_opset_version} ${set_enable_onnx_checker} --enable_dev_version=True > ${trans_det_log} 2>&1 " + eval $trans_model_cmd + last_status=${PIPESTATUS[0]} + status_check $last_status "${trans_model_cmd}" "${status_log}" "${model_name}" "${trans_det_log}" fi # python inference @@ -117,7 +130,7 @@ function func_paddle2onnx(){ set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " - elif [[ ${model_name} =~ "det" ]]; then + elif [[ ${model_name} =~ "det" ]] || [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " elif [[ ${model_name} =~ "rec" ]]; then @@ -136,7 +149,7 @@ function func_paddle2onnx(){ set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") set_rec_model_dir=$(func_set_params "${rec_model_key}" "${rec_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} ${set_rec_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " - elif [[ ${model_name} =~ "det" ]]; then + elif [[ ${model_name} =~ "det" ]]|| [ ${model_name} = "slanet" ] || [ ${model_name} = "en_table_structure" ]; then set_det_model_dir=$(func_set_params "${det_model_key}" "${det_save_file_value}") infer_model_cmd="${python} ${inference_py} ${set_gpu} ${set_img_dir} ${set_det_model_dir} --use_onnx=True > ${_save_log_path} 2>&1 " elif [[ ${model_name} =~ "rec" ]]; then