未验证 提交 bb95527f 编写于 作者: Z zhoujun 提交者: GitHub

Merge pull request #6618 from WenmuZhou/tablemaster

[WIP] add TableMaster
Global:
use_gpu: true
epoch_num: 17
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/table_master/
save_epoch_step: 17
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: MultiStepDecay
learning_rate: 0.001
milestones: [12, 15]
gamma: 0.1
warmup_epoch: 0.02
regularizer:
name: L2
factor: 0.0
Architecture:
model_type: table
algorithm: TableMaster
Backbone:
name: TableResNetExtra
gcb_config:
ratio: 0.0625
headers: 1
att_scale: False
fusion_type: channel_add
layers: [False, True, True, True]
layers: [1,2,5,3]
Head:
name: TableMasterHead
hidden_size: 512
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
Loss:
name: TableMasterLoss
ignore_index: 42 # set to len of dict + 3
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: True
batch_size_per_card: 10
drop_last: True
num_workers: 8
Eval:
dataset:
name: PubTabDataSet
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 10
num_workers: 8
\ No newline at end of file
......@@ -4,21 +4,20 @@ Global:
log_smooth_window: 20
print_batch_step: 5
save_model_dir: ./output/table_mv3/
save_epoch_step: 3
save_epoch_step: 400
# evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400]
cal_metric_during_train: True
pretrained_model:
checkpoints:
checkpoints:
save_inference_dir:
use_visualdl: False
infer_img: doc/table/table.jpg
infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_mv3
# for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
max_text_length: 800
infer_mode: False
process_total_num: 0
process_cut_num: 0
......@@ -44,11 +43,8 @@ Architecture:
Head:
name: TableAttentionHead
hidden_size: 256
l2_decay: 0.00001
loc_type: 2
max_text_length: 100
max_elem_length: 800
max_cell_num: 500
max_text_length: 800
Loss:
name: TableAttentionLoss
......@@ -61,28 +57,34 @@ PostProcess:
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: false # cost many time, set False for training
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: True
batch_size_per_card: 32
......@@ -92,24 +94,29 @@ Train:
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl
data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
transforms:
- DecodeImage: # load image
img_mode: BGR
channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage:
max_len: 488
- TableLabelEncode:
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- PaddingTableImage:
size: [488, 488]
- ToCHWImage:
- KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask']
keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader:
shuffle: False
drop_last: False
......
# FCENet
- [1. 算法简介](#1)
- [2. 环境配置](#2)
- [3. 模型训练、评估、预测](#3)
- [3.1 训练](#3-1)
- [3.2 评估](#3-2)
- [3.3 预测](#3-3)
- [4. 推理部署](#4)
- [4.1 Python推理](#4-1)
- [4.2 C++推理](#4-2)
- [4.3 Serving服务化部署](#4-3)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
- [1. 算法简介](#1-算法简介)
- [2. 环境配置](#2-环境配置)
- [3. 模型训练、评估、预测](#3-模型训练评估预测)
- [4. 推理部署](#4-推理部署)
- [4.1 Python推理](#41-python推理)
- [4.2 C++推理](#42-c推理)
- [4.3 Serving服务化部署](#43-serving服务化部署)
- [4.4 更多推理部署](#44-更多推理部署)
- [5. FAQ](#5-faq)
- [引用](#引用)
<a name="1"></a>
## 1. 算法简介
......
# OCR算法
- [1. 两阶段算法](#1-两阶段算法)
- [1.1 文本检测算法](#11-文本检测算法)
- [1.2 文本识别算法](#12-文本识别算法)
- [1.1 文本检测算法](#11-文本检测算法)
- [1.2 文本识别算法](#12-文本识别算法)
- [2. 端到端算法](#2-端到端算法)
- [3. 表格识别算法](#3-表格识别算法)
本文给出了PaddleOCR已支持的OCR算法列表,以及每个算法在**英文公开数据集**上的模型和指标,主要用于算法简介和算法性能对比,更多包括中文在内的其他数据集上的模型请参考[PP-OCR v2.0 系列模型下载](./models_list.md)
......@@ -96,3 +97,14 @@
已支持的端到端OCR算法列表(戳链接获取使用教程):
- [x] [PGNet](./algorithm_e2e_pgnet.md)
## 3. 表格识别算法
已支持的表格识别算法列表(戳链接获取使用教程):
- [x] [TableMaster](./algorithm_table_master.md)
在PubTabNet表格识别公开数据集上,算法效果如下:
|模型|骨干网络|配置文件|acc|下载链接|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# 表格识别算法-TableMASTER
- [1. 算法简介](#1-算法简介)
- [2. 环境配置](#2-环境配置)
- [3. 模型训练、评估、预测](#3-模型训练评估预测)
- [4. 推理部署](#4-推理部署)
- [4.1 Python推理](#41-python推理)
- [4.2 C++推理部署](#42-c推理部署)
- [4.3 Serving服务化部署](#43-serving服务化部署)
- [4.4 更多推理部署](#44-更多推理部署)
- [5. FAQ](#5-faq)
- [引用](#引用)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
在PubTabNet表格识别公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|acc|下载链接|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
上述TableMaster模型使用PubTabNet表格识别公开数据集训练得到,数据集下载可参考 [table_datasets](./dataset/table_datasets.md)
数据下载完成后,请参考[文本识别教程](./recognition.md)进行训练。PaddleOCR对代码进行了模块化,训练不同的模型只需要**更换配置文件**即可。
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。以基于TableResNetExtra骨干网络,在PubTabNet数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否为所正确的字典文件。
转换成功后,在目录下有三个文件:
```
./inference/table_master/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
cd ppstructure/
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='docs/table'。
```
执行命令后,上面图像的预测结果(结构信息和表格中每个单元格的坐标)会打印到屏幕上,同时会保存单元格坐标的可视化结果。示例如下:
结果如下:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
**注意**
- TableMaster在推理时比较慢,建议使用GPU进行使用。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持TableMaster,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
# OCR Algorithms
- [1. Two-stage Algorithms](#1)
* [1.1 Text Detection Algorithms](#11)
* [1.2 Text Recognition Algorithms](#12)
- [2. End-to-end Algorithms](#2)
- [1. Two-stage Algorithms](#1-two-stage-algorithms)
- [1.1 Text Detection Algorithms](#11-text-detection-algorithms)
- [1.2 Text Recognition Algorithms](#12-text-recognition-algorithms)
- [2. End-to-end Algorithms](#2-end-to-end-algorithms)
- [3. Table Recognition Algorithms](#3-table-recognition-algorithms)
This tutorial lists the OCR algorithms supported by PaddleOCR, as well as the models and metrics of each algorithm on **English public datasets**. It is mainly used for algorithm introduction and algorithm performance comparison. For more models on other datasets including Chinese, please refer to [PP-OCR v2.0 models list](./models_list_en.md).
......@@ -95,3 +96,15 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
Supported end-to-end algorithms (Click the link to get the tutorial):
- [x] [PGNet](./algorithm_e2e_pgnet_en.md)
<a name="3"></a>
## 3. Table Recognition Algorithms
Supported table recognition algorithms (Click the link to get the tutorial):
- [x] [TableMaster](./algorithm_table_master_en.md)
On the PubTabNet dataset, the algorithm result is as follows:
|Model|Backbone|Config|Acc|Download link|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# Table Recognition Algorithm-TableMASTER
- [1. Introduction](#1-introduction)
- [2. Environment](#2-environment)
- [3. Model Training / Evaluation / Prediction](#3-model-training--evaluation--prediction)
- [4. Inference and Deployment](#4-inference-and-deployment)
- [4.1 Python Inference](#41-python-inference)
- [4.2 C++ Inference](#42-c-inference)
- [4.3 Serving](#43-serving)
- [4.4 More](#44-more)
- [5. FAQ](#5-faq)
- [Citation](#citation)
<a name="1"></a>
## 1. Introduction
Paper:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
On the PubTabNet table recognition public data set, the algorithm reproduction acc is as follows:
|Model|Backbone|Cnnfig|Acc|Download link|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
The above TableMaster model is trained using the PubTabNet table recognition public dataset. For the download of the dataset, please refer to [table_datasets](./dataset/table_datasets_en.md).
After the data download is complete, please refer to [Text Recognition Training Tutorial](./recognition_en.md) for training. PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different models.
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved in the TableMaster table recognition training process into an inference model. Taking the model based on the TableResNetExtra backbone network and trained on the PubTabNet dataset as example ([model download link](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)), you can use the following command to convert:
```shell
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
**Note: **
- If you trained the model on your own dataset and adjusted the dictionary file, please pay attention to whether the `character_dict_path` in the modified configuration file is the correct dictionary file
Execute the following command for model inference:
```shell
cd ppstructure/
# When predicting all images in a folder, you can modify image_dir to a folder, such as --image_dir='docs/table'.
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
```
After executing the command, the prediction results of the above image (structural information and the coordinates of each cell in the table) are printed to the screen, and the visualization of the cell coordinates is also saved. An example is as follows:
result:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
**Note**:
- TableMaster is relatively slow during inference, and it is recommended to use GPU for use.
<a name="4-2"></a>
### 4.2 C++ Inference
Since the post-processing is not written in CPP, the TableMaster does not support CPP inference.
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
......@@ -37,7 +37,7 @@ from .label_ops import *
from .east_process import *
from .sast_process import *
from .pg_process import *
from .gen_table_mask import *
from .table_ops import *
from .vqa import *
......
......@@ -562,171 +562,210 @@ class SRNLabelEncode(BaseRecLabelEncode):
return idx
class TableLabelEncode(object):
class TableLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
max_elem_length,
max_cell_num,
character_dict_path,
span_weight=1.0,
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
**kwargs):
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
self.max_text_len = max_text_length
self.lower = False
self.learn_empty_box = learn_empty_box
self.merge_no_span_structure = merge_no_span_structure
self.replace_empty_cell_token = replace_empty_cell_token
dict_character = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\r\n").split("\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
for line in lines:
line = line.decode('utf-8').strip("\n").strip("\r\n")
dict_character.append(line)
def get_span_idx_list(self):
span_idx_list = []
for elem in self.dict_elem:
if 'span' in elem:
span_idx_list.append(self.dict_elem[elem])
return span_idx_list
dict_character = self.add_special_char(dict_character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.idx2char = {v: k for k, v in self.dict.items()}
self.character = dict_character
self.point_num = point_num
self.pad_idx = self.dict[self.beg_str]
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
self.empty_bbox_token_dict = {
"[]": '<eb></eb>',
"[' ']": '<eb1></eb1>',
"['<b>', ' ', '</b>']": '<eb2></eb2>',
"['\\u2028', '\\u2028']": '<eb3></eb3>',
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
"['<b>', '</b>']": '<eb5></eb5>',
"['<i>', ' ', '</i>']": '<eb6></eb6>',
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
"['<i>', '</i>']": '<eb9></eb9>',
"['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
'<eb10></eb10>',
}
@property
def _max_text_len(self):
return self.max_text_len + 2
def __call__(self, data):
cells = data['cells']
structure = data['structure']['tokens']
structure = self.encode(structure, 'elem')
structure = data['structure']
if self.merge_no_span_structure:
structure = self._merge_no_span_structure(structure)
if self.replace_empty_cell_token:
structure = self._replace_empty_cell_token(structure, cells)
# remove empty token and add " " to span token
new_structure = []
for token in structure:
if token != '':
if 'span' in token and token[0] != ' ':
token = ' ' + token
new_structure.append(token)
# encode structure
structure = self.encode(new_structure)
if structure is None:
return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1]
structure = structure + [0] * (self.max_elem_length + 2 - len(structure)
)
structure = [self.start_idx] + structure + [self.end_idx
] # add sos abd eos
structure = structure + [self.pad_idx] * (self._max_text_len -
len(structure)) # pad
structure = np.array(structure)
data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td']
span_idx_list = self.get_span_idx_list()
td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2)
td_idx_list = np.where(td_idx_list)[0]
structure_mask = np.ones(
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32)
bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32)
img_height, img_width, img_ch = data['image'].shape
if len(span_idx_list) > 0:
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list)
span_weight = min(max(span_weight, 1.0), self.span_weight)
for cno in range(len(cells)):
if 'bbox' in cells[cno]:
bbox = cells[cno]['bbox'].copy()
bbox[0] = bbox[0] * 1.0 / img_width
bbox[1] = bbox[1] * 1.0 / img_height
bbox[2] = bbox[2] * 1.0 / img_width
bbox[3] = bbox[3] * 1.0 / img_height
td_idx = td_idx_list[cno]
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
if len(structure) > self._max_text_len:
return None
# encode box
bboxes = np.zeros(
(self._max_text_len, self.point_num * 2), dtype=np.float32)
bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
bbox_idx = 0
for i, token in enumerate(structure):
if self.idx2char[token] in self.td_token:
if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
'tokens']) > 0:
bbox = cells[bbox_idx]['bbox'].copy()
bbox = np.array(bbox, dtype=np.float32).reshape(-1)
bboxes[i] = bbox
bbox_masks[i] = 1.0
if self.learn_empty_box:
bbox_masks[i] = 1.0
bbox_idx += 1
data['bboxes'] = bboxes
data['bbox_masks'] = bbox_masks
return data
def encode(self, text, char_or_elem):
"""convert text-label into text-index.
def _merge_no_span_structure(self, structure):
"""
if char_or_elem == "char":
max_len = self.max_text_length
current_dict = self.dict_character
else:
max_len = self.max_elem_length
current_dict = self.dict_elem
if len(text) > max_len:
return None
if len(text) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
else:
return None
text_list = []
for char in text:
if char not in current_dict:
return None
text_list.append(current_dict[char])
if len(text_list) == 0:
if char_or_elem == "char":
return [self.dict_character['space']]
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
"""
new_structure = []
i = 0
while i < len(structure):
token = structure[i]
if token == '<td>':
token = '<td></td>'
i += 1
new_structure.append(token)
i += 1
return new_structure
def _replace_empty_cell_token(self, token_list, cells):
"""
This fun code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
"""
bbox_idx = 0
add_empty_bbox_token_list = []
for token in token_list:
if token in ['<td></td>', '<td', '<td>']:
if 'bbox' not in cells[bbox_idx].keys():
content = str(cells[bbox_idx]['tokens'])
token = self.empty_bbox_token_dict[content]
add_empty_bbox_token_list.append(token)
bbox_idx += 1
else:
return None
return text_list
add_empty_bbox_token_list.append(token)
return add_empty_bbox_token_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_character[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = np.array(self.dict_elem[self.beg_str])
elif beg_or_end == "end":
idx = np.array(self.dict_elem[self.end_str])
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class TableMasterLabelEncode(TableLabelEncode):
""" Convert between text-label and text-index """
def __init__(self,
max_text_length,
character_dict_path,
replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
**kwargs):
super(TableMasterLabelEncode, self).__init__(
max_text_length, character_dict_path, replace_empty_cell_token,
merge_no_span_structure, learn_empty_box, point_num, **kwargs)
self.pad_idx = self.dict[self.pad_str]
self.unknown_idx = self.dict[self.unknown_str]
@property
def _max_text_len(self):
return self.max_text_len
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
class TableBoxEncode(object):
def __init__(self, use_xywh=False, **kwargs):
self.use_xywh = use_xywh
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
if self.use_xywh and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
data['bboxes'] = bboxes
return data
def xyxy2xywh(self, bboxes):
"""
Convert coord (x1,y1,x2,y2) to (x,y,w,h).
where (x1,y1) is top-left, (x2,y2) is bottom-right.
(x,y) is bbox center and (w,h) is width and height.
:param bboxes: (x1, y1, x2, y2)
:return:
"""
new_bboxes = np.empty_like(bboxes)
new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
return new_bboxes
class SARLabelEncode(BaseRecLabelEncode):
......@@ -1013,7 +1052,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
use_space_char, **kwargs)
def __call__(self, data):
data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data)
data_out = dict()
......
......@@ -32,7 +32,7 @@ class GenTableMask(object):
self.shrink_h_max = 5
self.shrink_w_max = 5
self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影
projection_map = np.ones_like(erosion)
......@@ -48,10 +48,12 @@ class GenTableMask(object):
in_text = False # 是否遍历到了字符区内
box_list = []
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
......@@ -70,7 +72,8 @@ class GenTableMask(object):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape
# 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV)
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
cv2.THRESH_BINARY_INV)
# 纵向腐蚀
if h < w:
kernel = np.ones((2, 1), np.uint8)
......@@ -95,10 +98,12 @@ class GenTableMask(object):
box_list = []
spilt_threshold = 0
for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了
if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True
start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了
elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i
in_text = False
if end_idx - start_idx <= 2:
......@@ -120,7 +125,8 @@ class GenTableMask(object):
h_end = h
word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h)
w_split_list, w_projection_map = self.projection(word_img.T,
word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0:
h_start -= 1
......@@ -170,75 +176,54 @@ class GenTableMask(object):
for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom])
left, top, right, bottom = self.shrink_bbox(
[left, top, right, bottom])
if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img
else:
mask_img[top:bottom, left:right, :] = (255, 255, 255)
mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img
return data
class ResizeTableImage(object):
def __init__(self, max_len, **kwargs):
def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
**kwargs):
super(ResizeTableImage, self).__init__()
self.max_len = max_len
self.resize_bboxes = resize_bboxes
self.infer_mode = infer_mode
def get_img_bbox(self, cells):
bbox_list = []
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
def __call__(self, data):
img = data['image']
height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0)
ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio)
resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = []
for bno in range(len(bbox_list)):
left, top, right, bottom = bbox_list[bno].copy()
left = int(left * ratio)
top = int(top * ratio)
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
resize_img = cv2.resize(img, (resize_w, resize_h))
if self.resize_bboxes and not self.infer_mode:
data['bboxes'] = data['bboxes'] * ratio
data['image'] = resize_img
data['src_img'] = img
data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
data['max_len'] = self.max_len
return data
class PaddingTableImage(object):
def __init__(self, **kwargs):
def __init__(self, size, **kwargs):
super(PaddingTableImage, self).__init__()
self.size = size
def __call__(self, data):
img = data['image']
max_len = data['max_len']
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32)
pad_h, pad_w = self.size
padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img
shape = data['shape'].tolist()
shape.extend([pad_h, pad_w])
data['shape'] = np.array(shape)
return data
\ No newline at end of file
......@@ -16,6 +16,7 @@ import os
import random
from paddle.io import Dataset
import json
from copy import deepcopy
from .imaug import transform, create_operators
......@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path')
label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path)
with open(label_file_path, "rb") as f:
self.data_lines = f.readlines()
self.data_idx_order_list = list(range(len(self.data_lines)))
if mode.lower() == "train":
self.mode = mode.lower()
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
# self.check(config['Global']['max_text_length'])
if mode.lower() == "train" and self.do_shuffle:
self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def check(self, max_text_length):
data_lines = []
for line in self.data_lines:
data_line = line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
cells = info['html']['cells'].copy()
structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
self.logger.warning("{} does not exist!".format(img_path))
continue
if len(structure) == 0 or len(structure) > max_text_length:
continue
# data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
data_lines.append(line)
self.data_lines = data_lines
def shuffle_data_random(self):
if self.do_shuffle:
random.seed(self.seed)
......@@ -68,47 +99,35 @@ class PubTabDataSet(Dataset):
data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
select_flag = True
if self.do_hard_select:
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1):
select_flag = False
if self.table_select_type:
structure = info['html']['structure']['tokens'].copy()
structure_str = ''.join(structure)
table_type = "simple"
if 'colspan' in structure_str or 'rowspan' in structure_str:
table_type = "complex"
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1):
select_flag = False
if select_flag:
cells = info['html']['cells'].copy()
structure = info['html']['structure'].copy()
img_path = os.path.join(self.data_dir, file_name)
data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
cells = info['html']['cells'].copy()
structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
data = {
'img_path': img_path,
'cells': cells,
'structure': structure,
'file_name': file_name
}
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
except:
import traceback
err = traceback.format_exc()
self.logger.error(
"When parsing line {}, error happened with msg: {}".format(
data_line, e))
data_line, err))
outs = None
if outs is None:
return self.__getitem__(np.random.randint(self.__len__()))
rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs
def __len__(self):
return len(self.data_idx_order_list)
return len(self.data_lines)
......@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss
# table loss
from .table_att_loss import TableAttentionLoss
from .table_master_loss import TableMasterLoss
# vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
......@@ -61,7 +61,8 @@ def build_loss(config):
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss'
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss'
]
config = copy.deepcopy(config)
module_name = config.pop('name')
......
......@@ -20,15 +20,21 @@ import paddle
from paddle import nn
from paddle.nn import functional as F
class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs):
def __init__(self,
structure_weight,
loc_weight,
use_giou=False,
giou_weight=1.0,
**kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight
self.loc_weight = loc_weight
self.use_giou = use_giou
self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
'''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
......@@ -47,9 +53,10 @@ class TableAttentionLoss(nn.Layer):
inters = iw * ih
# union
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * (
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
) * (bbox[:, 3] - bbox[:, 1] +
1e-3) - inters + eps
# ious
ious = inters / uni
......@@ -79,30 +86,34 @@ class TableAttentionLoss(nn.Layer):
structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:]
if len(batch) == 6:
structure_mask = batch[5].astype("int64")
structure_mask = structure_mask[:, 1:]
structure_mask = paddle.reshape(structure_mask, [-1])
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
structure_probs = paddle.reshape(structure_probs,
[-1, structure_probs.shape[-1]])
structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets)
if len(batch) == 6:
structure_loss = structure_loss * structure_mask
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
structure_loss = paddle.mean(structure_loss) * self.structure_weight
loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[4].astype("float32")
loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight
loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
loc_targets) * self.loc_weight
if self.use_giou:
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou}
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss,
"loc_loss_giou": loc_loss_giou
}
else:
total_loss = structure_loss + loc_loss
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss}
\ No newline at end of file
total_loss = structure_loss + loc_loss
return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/tree/master/mmocr/models/textrecog/losses
"""
import paddle
from paddle import nn
class TableMasterLoss(nn.Layer):
def __init__(self, ignore_index=-1):
super(TableMasterLoss, self).__init__()
self.structure_loss = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction='mean')
self.box_loss = nn.L1Loss(reduction='sum')
self.eps = 1e-12
def forward(self, predicts, batch):
# structure_loss
structure_probs = predicts['structure_probs']
structure_targets = batch[1]
structure_targets = structure_targets[:, 1:]
structure_probs = structure_probs.reshape(
[-1, structure_probs.shape[-1]])
structure_targets = structure_targets.reshape([-1])
structure_loss = self.structure_loss(structure_probs, structure_targets)
structure_loss = structure_loss.mean()
losses = dict(structure_loss=structure_loss)
# box loss
bboxes_preds = predicts['loc_preds']
bboxes_targets = batch[2][:, 1:, :]
bbox_masks = batch[3][:, 1:]
# mask empty-bbox or non-bbox structure token's bbox.
masked_bboxes_preds = bboxes_preds * bbox_masks
masked_bboxes_targets = bboxes_targets * bbox_masks
# horizon loss (x and width)
horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2],
masked_bboxes_targets[:, :, 0::2])
horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps)
# vertical loss (y and height)
vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2],
masked_bboxes_targets[:, :, 1::2])
vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps)
horizon_loss = horizon_loss.mean()
vertical_loss = vertical_loss.mean()
all_loss = structure_loss + horizon_loss + vertical_loss
losses.update({
'loss': all_loss,
'horizon_bbox_loss': horizon_loss,
'vertical_bbox_loss': vertical_loss
})
return losses
......@@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppocr.metrics.det_metric import DetMetric
class TableMetric(object):
def __init__(self, main_indicator='acc', **kwargs):
class TableStructureMetric(object):
def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
self.main_indicator = main_indicator
self.eps = 1e-5
self.eps = eps
self.reset()
def __call__(self, pred, batch, *args, **kwargs):
structure_probs = pred['structure_probs'].numpy()
structure_labels = batch[1]
def __call__(self, pred_label, batch=None, *args, **kwargs):
preds, labels = pred_label
pred_structure_batch_list = preds['structure_batch_list']
gt_structure_batch_list = labels['structure_batch_list']
correct_num = 0
all_num = 0
structure_probs = np.argmax(structure_probs, axis=2)
structure_labels = structure_labels[:, 1:]
batch_size = structure_probs.shape[0]
for bno in range(batch_size):
all_num += 1
if (structure_probs[bno] == structure_labels[bno]).all():
for (pred, pred_conf), target in zip(pred_structure_batch_list,
gt_structure_batch_list):
pred_str = ''.join(pred)
target_str = ''.join(target)
if pred_str == target_str:
correct_num += 1
all_num += 1
self.correct_num += correct_num
self.all_num += all_num
return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self):
"""
......@@ -49,3 +50,89 @@ class TableMetric(object):
def reset(self):
self.correct_num = 0
self.all_num = 0
self.len_acc_num = 0
self.token_nums = 0
self.anys_dict = dict()
class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
point_num=2,
**kwargs):
"""
@param sub_metrics: configs of sub_metric
@param main_matric: main_matric for save best_model
@param kwargs:
"""
self.structure_metric = TableStructureMetric()
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.point_num = point_num
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
self.structure_metric(pred_label)
if self.bbox_metric is not None:
self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
def prepare_bbox_metric_input(self, pred_label):
pred_bbox_batch_list = []
gt_ignore_tags_batch_list = []
gt_bbox_batch_list = []
preds, labels = pred_label
batch_num = len(preds['bbox_batch_list'])
for batch_idx in range(batch_num):
# pred
pred_bbox_list = [
self.format_box(pred_box)
for pred_box in preds['bbox_batch_list'][batch_idx]
]
pred_bbox_batch_list.append({'points': pred_bbox_list})
# gt
gt_bbox_list = []
gt_ignore_tags_list = []
for gt_box in labels['bbox_batch_list'][batch_idx]:
gt_bbox_list.append(self.format_box(gt_box))
gt_ignore_tags_list.append(0)
gt_bbox_batch_list.append(gt_bbox_list)
gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
return [
pred_bbox_batch_list,
[0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
]
def get_metric(self):
structure_metric = self.structure_metric.get_metric()
if self.bbox_metric is None:
return structure_metric
bbox_metric = self.bbox_metric.get_metric()
if self.main_indicator == self.bbox_metric.main_indicator:
output = bbox_metric
for sub_key in structure_metric:
output["structure_metric_{}".format(
sub_key)] = structure_metric[sub_key]
else:
output = structure_metric
for sub_key in bbox_metric:
output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
return output
def reset(self):
self.structure_metric.reset()
if self.bbox_metric is not None:
self.bbox_metric.reset()
def format_box(self, box):
if self.point_num == 2:
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.point_num == 4:
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
......@@ -22,6 +22,9 @@ def build_backbone(config, model_type):
from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST
support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
support_dict.append('TableResNetExtra')
elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
gcb_config=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
self.downsample = downsample
self.stride = stride
self.gcb_config = gcb_config
if self.gcb_config is not None:
gcb_ratio = gcb_config['ratio']
gcb_headers = gcb_config['headers']
att_scale = gcb_config['att_scale']
fusion_type = gcb_config['fusion_type']
self.context_block = MultiAspectGCAttention(
inplanes=planes,
ratio=gcb_ratio,
headers=gcb_headers,
att_scale=att_scale,
fusion_type=fusion_type)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.gcb_config is not None:
out = self.context_block(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
def get_gcb_config(gcb_config, layer):
if gcb_config is None or not gcb_config['layers'][layer]:
return None
else:
return gcb_config
class TableResNetExtra(nn.Layer):
def __init__(self, layers, in_channels=3, gcb_config=None):
assert len(layers) >= 4
super(TableResNetExtra, self).__init__()
self.inplanes = 128
self.conv1 = nn.Conv2D(
in_channels,
64,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(64)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(128)
self.relu2 = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer1 = self._make_layer(
BasicBlock,
256,
layers[0],
stride=1,
gcb_config=get_gcb_config(gcb_config, 0))
self.conv3 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(256)
self.relu3 = nn.ReLU()
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer2 = self._make_layer(
BasicBlock,
256,
layers[1],
stride=1,
gcb_config=get_gcb_config(gcb_config, 1))
self.conv4 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn4 = nn.BatchNorm2D(256)
self.relu4 = nn.ReLU()
self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer3 = self._make_layer(
BasicBlock,
512,
layers[2],
stride=1,
gcb_config=get_gcb_config(gcb_config, 2))
self.conv5 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn5 = nn.BatchNorm2D(512)
self.relu5 = nn.ReLU()
self.layer4 = self._make_layer(
BasicBlock,
512,
layers[3],
stride=1,
gcb_config=get_gcb_config(gcb_config, 3))
self.conv6 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn6 = nn.BatchNorm2D(512)
self.relu6 = nn.ReLU()
self.out_channels = [256, 256, 512]
def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion), )
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
gcb_config=gcb_config))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
f = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
f.append(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu4(x)
f.append(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu5(x)
x = self.layer4(x)
x = self.conv6(x)
x = self.bn6(x)
x = self.relu6(x)
f.append(x)
return f
class MultiAspectGCAttention(nn.Layer):
def __init__(self,
inplanes,
ratio,
headers,
pooling_type='att',
att_scale=False,
fusion_type='channel_add'):
super(MultiAspectGCAttention, self).__init__()
assert pooling_type in ['avg', 'att']
assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
assert inplanes % headers == 0 and inplanes >= 8 # inplanes must be divided by headers evenly
self.headers = headers
self.inplanes = inplanes
self.ratio = ratio
self.planes = int(inplanes * ratio)
self.pooling_type = pooling_type
self.fusion_type = fusion_type
self.att_scale = False
self.single_header_inplanes = int(inplanes / headers)
if pooling_type == 'att':
self.conv_mask = nn.Conv2D(
self.single_header_inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(axis=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2D(1)
if fusion_type == 'channel_add':
self.channel_add_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
elif fusion_type == 'channel_concat':
self.channel_concat_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
# for concat
self.cat_conv = nn.Conv2D(
2 * self.inplanes, self.inplanes, kernel_size=1)
elif fusion_type == 'channel_mul':
self.channel_mul_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
def spatial_pool(self, x):
batch, channel, height, width = x.shape
if self.pooling_type == 'att':
# [N*headers, C', H , W] C = headers * C'
x = x.reshape([
batch * self.headers, self.single_header_inplanes, height, width
])
input_x = x
# [N*headers, C', H * W] C = headers * C'
# input_x = input_x.view(batch, channel, height * width)
input_x = input_x.reshape([
batch * self.headers, self.single_header_inplanes,
height * width
])
# [N*headers, 1, C', H * W]
input_x = input_x.unsqueeze(1)
# [N*headers, 1, H, W]
context_mask = self.conv_mask(x)
# [N*headers, 1, H * W]
context_mask = context_mask.reshape(
[batch * self.headers, 1, height * width])
# scale variance
if self.att_scale and self.headers > 1:
context_mask = context_mask / paddle.sqrt(
self.single_header_inplanes)
# [N*headers, 1, H * W]
context_mask = self.softmax(context_mask)
# [N*headers, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
context = paddle.matmul(input_x, context_mask)
# [N, headers * C', 1, 1]
context = context.reshape(
[batch, self.headers * self.single_header_inplanes, 1, 1])
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.fusion_type == 'channel_mul':
# [N, C, 1, 1]
channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
elif self.fusion_type == 'channel_add':
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
else:
# [N, C, 1, 1]
channel_concat_term = self.channel_concat_conv(context)
# use concat
_, C1, _, _ = channel_concat_term.shape
N, C2, H, W = out.shape
out = paddle.concat(
[out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
out = self.cat_conv(out)
out = F.layer_norm(out, [self.inplanes, H, W])
out = F.relu(out)
return out
......@@ -42,12 +42,13 @@ def build_head(config):
from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead
from .table_master_head import TableMasterHead
support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead'
'MultiHead', 'ABINetHead', 'TableMasterHead'
]
#table head
......
......@@ -21,6 +21,8 @@ import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np
from .rec_att_head import AttentionGRUCell
class TableAttentionHead(nn.Layer):
def __init__(self,
......@@ -28,21 +30,19 @@ class TableAttentionHead(nn.Layer):
hidden_size,
loc_type,
in_max_len=488,
max_text_length=100,
max_elem_length=800,
max_cell_num=500,
max_text_length=800,
out_channels=30,
point_num=2,
**kwargs):
super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1]
self.hidden_size = hidden_size
self.elem_num = 30
self.out_channels = out_channels
self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num)
self.input_size, hidden_size, self.out_channels, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.out_channels)
self.loc_type = loc_type
self.in_max_len = in_max_len
......@@ -50,12 +50,13 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4)
else:
if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1)
self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1)
self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4)
self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size,
point_num * 2)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
......@@ -77,9 +78,9 @@ class TableAttentionHead(nn.Layer):
output_hiddens = []
if self.training and targets is not None:
structure = targets[0]
for i in range(self.max_elem_length + 1):
for i in range(self.max_text_length + 1):
elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num)
structure[:, i], onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
......@@ -102,11 +103,11 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None
outputs = None
alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length)
max_text_length = paddle.to_tensor(self.max_text_length)
i = 0
while i < max_elem_length + 1:
while i < max_text_length + 1:
elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num)
temp_elem, onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
......@@ -128,119 +129,3 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/decoders/master_decoder.py
"""
import copy
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
class TableMasterHead(nn.Layer):
"""
Split to two transformer header at the last layer.
Cls_layer is used to structure token classification.
Bbox_layer is used to regress bbox coord.
"""
def __init__(self,
in_channels,
out_channels=30,
headers=8,
d_ff=2048,
dropout=0,
max_text_length=500,
point_num=2,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
self.layers = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
self.cls_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.bbox_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, point_num * 2),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
self.positional_encoding = PositionalEncoding(d_model=hidden_size)
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
self.point_num = point_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
"""
Make mask for self attention.
:param src: [b, c, h, l_src]
:param tgt: [b, l_tgt]
:return:
"""
trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
tgt_len = paddle.shape(tgt)[1]
trg_sub_mask = paddle.tril(
paddle.ones(
([tgt_len, tgt_len]), dtype=paddle.float32))
tgt_mask = paddle.logical_and(
trg_pad_mask.astype(paddle.float32), trg_sub_mask)
return tgt_mask.astype(paddle.float32)
def decode(self, input, feature, src_mask, tgt_mask):
# main process of transformer decoder.
x = self.embedding(input) # x: 1*x*512, feature: 1*3600,512
x = self.positional_encoding(x)
# origin transformer layers
for i, layer in enumerate(self.layers):
x = layer(x, feature, src_mask, tgt_mask)
# cls head
for layer in self.cls_layer:
cls_x = layer(x, feature, src_mask, tgt_mask)
cls_x = self.norm(cls_x)
# bbox head
for layer in self.bbox_layer:
bbox_x = layer(x, feature, src_mask, tgt_mask)
bbox_x = self.norm(bbox_x)
return self.cls_fc(cls_x), self.bbox_fc(bbox_x)
def greedy_forward(self, SOS, feature):
input = SOS
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.point_num * 2])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
out_step, bbox_output_step = self.decode(input, feature, None,
target_mask)
prob = F.softmax(out_step, axis=-1)
next_word = prob.argmax(axis=2, dtype="int64")
input = paddle.concat(
[input, next_word[:, -1].unsqueeze(-1)], axis=1)
if i == self.max_text_length:
output = out_step
bbox_output = bbox_output_step
return output, bbox_output
def forward_train(self, out_enc, targets):
# x is token of label
# feat is feature after backbone before pe.
# out_enc is feature after pe.
padded_targets = targets[0]
src_mask = None
tgt_mask = self.make_mask(padded_targets[:, :-1])
output, bbox_output = self.decode(padded_targets[:, :-1], out_enc,
src_mask, tgt_mask)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward_test(self, out_enc):
batch_size = out_enc.shape[0]
SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS
output, bbox_output = self.greedy_forward(SOS, out_enc)
output = F.softmax(output)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward(self, feat, targets=None):
feat = feat[-1]
b, c, h, w = feat.shape
feat = feat.reshape([b, c, h * w]) # flatten 2D feature map
feat = feat.transpose((0, 2, 1))
out_enc = self.positional_encoding(feat)
if self.training:
return self.forward_train(out_enc, targets)
return self.forward_test(out_enc)
class DecoderLayer(nn.Layer):
"""
Decoder is made of self attention, srouce attention and feed forward.
"""
def __init__(self, headers, d_model, dropout, d_ff):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(headers, d_model, dropout)
self.src_attn = MultiHeadAttention(headers, d_model, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.sublayer = clones(SubLayerConnection(d_model, dropout), 3)
def forward(self, x, feature, src_mask, tgt_mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](
x, lambda x: self.src_attn(x, feature, feature, src_mask))
return self.sublayer[2](x, self.feed_forward)
class MultiHeadAttention(nn.Layer):
def __init__(self, headers, d_model, dropout):
super(MultiHeadAttention, self).__init__()
assert d_model % headers == 0
self.d_k = int(d_model / headers)
self.headers = headers
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
B = query.shape[0]
# 1) Do all the linear projections in batch from d_model => h x d_k
query, key, value = \
[l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
for l, x in zip(self.linears, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch
x, self.attn = self_attention(
query, key, value, mask=mask, dropout=self.dropout)
x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
return self.linears[-1](x)
class FeedForward(nn.Layer):
def __init__(self, d_model, d_ff, dropout):
super(FeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SubLayerConnection(nn.Layer):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def __init__(self, size, dropout):
super(SubLayerConnection, self).__init__()
self.norm = nn.LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))
def masked_fill(x, mask, value):
mask = mask.astype(x.dtype)
return x * paddle.logical_not(mask).astype(x.dtype) + mask * value
def self_attention(query, key, value, mask=None, dropout=None):
"""
Compute 'Scale Dot Product Attention'
"""
d_k = value.shape[-1]
score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k))
if mask is not None:
# score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
score = masked_fill(score, mask == 0, -6.55e4) # for fp16
p_attn = F.softmax(score, axis=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return paddle.matmul(p_attn, value), p_attn
def clones(module, N):
""" Produce N identical layers """
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
class Embeddings(nn.Layer):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, *input):
x = input[0]
return self.lut(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Layer):
""" Implement the PE function. """
def __init__(self, d_model, dropout=0., max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = paddle.zeros([max_len, d_model])
position = paddle.arange(0, max_len).unsqueeze(1).astype('float32')
div_term = paddle.exp(
paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, feat, **kwargs):
feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512
return self.dropout(feat)
......@@ -343,3 +343,46 @@ class DecayLearningRate(object):
power=self.factor,
end_lr=self.end_lr)
return learning_rate
class MultiStepDecay(object):
"""
Piecewise learning rate decay
Args:
step_each_epoch(int): steps each epoch
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
learning_rate,
milestones,
step_each_epoch,
gamma,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(MultiStepDecay, self).__init__()
self.milestones = [step_each_epoch * e for e in milestones]
self.learning_rate = learning_rate
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
learning_rate = lr.MultiStepDecay(
learning_rate=self.learning_rate,
milestones=self.milestones,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
......@@ -26,12 +26,13 @@ from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \
DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
def build_post_process(config, global_config=None):
......@@ -42,7 +43,8 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode'
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode'
]
if config['name'] == 'PSEPostProcess':
......
......@@ -380,146 +380,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
return idx
class TableLabelDecode(object):
""" """
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
bbox_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(bbox_preds, paddle.Tensor):
bbox_preds = bbox_preds.numpy()
shape_list = batch[-1]
result = self.decode(structure_probs, bbox_preds, shape_list)
if len(batch) == 1: # only contains shape
return result
label_decode_result = self.decode_label(batch)
return result, label_decode_result
def decode(self, structure_probs, bbox_preds, shape_list):
"""convert text-label into text-index.
"""
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
score_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
text = self.character[char_idx]
if text in self.td_token:
bbox = bbox_preds[batch_idx, idx]
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_list.append(text)
score_list.append(structure_probs[batch_idx, idx])
structure_batch_list.append([structure_list, np.mean(score_list)])
bbox_batch_list.append(np.array(bbox_list))
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def decode_label(self, batch):
"""convert text-label into text-index.
"""
structure_idx = batch[1]
gt_bbox_list = batch[2]
shape_list = batch[-1]
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
structure_list.append(self.character[char_idx])
bbox = gt_bbox_list[batch_idx][idx]
if bbox.sum() != 0:
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_batch_list.append(structure_list)
bbox_batch_list.append(bbox_list)
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
src_h = h / ratio_h
src_w = w / ratio_w
bbox[0::2] *= src_w
bbox[1::2] *= src_h
return bbox
class TableMasterLabelDecode(TableLabelDecode):
""" """
def __init__(self, character_dict_path, box_shape='ori', **kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
], 'The shape used for box normalization must be ori or pad'
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
def get_ignored_tokens(self):
pad_idx = self.dict[self.pad_str]
start_idx = self.dict[self.beg_str]
end_idx = self.dict[self.end_str]
unknown_idx = self.dict[self.unknown_str]
return [start_idx, end_idx, pad_idx, unknown_idx]
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
if self.box_shape == 'pad':
h, w = pad_h, pad_w
bbox[0::2] *= w
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
return bbox
<thead>
<tr>
<td></td>
</tr>
</thead>
<tbody>
<eb></eb>
</tbody>
<td
colspan="5"
>
</td>
colspan="2"
colspan="3"
<eb2></eb2>
<eb1></eb1>
rowspan="2"
colspan="4"
colspan="6"
rowspan="3"
colspan="9"
colspan="10"
colspan="7"
rowspan="4"
rowspan="5"
rowspan="9"
colspan="8"
rowspan="8"
rowspan="6"
rowspan="7"
rowspan="10"
<eb3></eb3>
<eb4></eb4>
<eb5></eb5>
<eb6></eb6>
<eb7></eb7>
<eb8></eb8>
<eb9></eb9>
<eb10></eb10>
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cv2
import os
import numpy as np
from PIL import Image, ImageDraw, ImageFont
......@@ -110,3 +111,16 @@ def draw_re_results(image,
img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new)
def draw_rectangle(img_path, boxes, use_xywh=False):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
\ No newline at end of file
......@@ -35,7 +35,7 @@
|模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
|en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
<a name="3"></a>
## 3. VQA模型
......
......@@ -23,43 +23,63 @@ os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
import cv2
import numpy as np
import time
import json
import tools.infer.utility as utility
from ppocr.data import create_operators, transform
from ppocr.postprocess import build_post_process
from ppocr.utils.logging import get_logger
from ppocr.utils.utility import get_image_file_list, check_and_read_gif
from ppocr.utils.visual import draw_rectangle
from ppstructure.utility import parse_args
logger = get_logger()
def build_pre_process_list(args):
resize_op = {'ResizeTableImage': {'max_len': args.table_max_len, }}
pad_op = {
'PaddingTableImage': {
'size': [args.table_max_len, args.table_max_len]
}
}
normalize_op = {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225] if
args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'mean': [0.485, 0.456, 0.406] if
args.table_algorithm not in ['TableMaster'] else [0.5, 0.5, 0.5],
'scale': '1./255.',
'order': 'hwc'
}
}
to_chw_op = {'ToCHWImage': None}
keep_keys_op = {'KeepKeys': {'keep_keys': ['image', 'shape']}}
if args.table_algorithm not in ['TableMaster']:
pre_process_list = [
resize_op, normalize_op, pad_op, to_chw_op, keep_keys_op
]
else:
pre_process_list = [
resize_op, pad_op, normalize_op, to_chw_op, keep_keys_op
]
return pre_process_list
class TableStructurer(object):
def __init__(self, args):
pre_process_list = [{
'ResizeTableImage': {
'max_len': args.table_max_len
}
}, {
'NormalizeImage': {
'std': [0.229, 0.224, 0.225],
'mean': [0.485, 0.456, 0.406],
'scale': '1./255.',
'order': 'hwc'
pre_process_list = build_pre_process_list(args)
if args.table_algorithm not in ['TableMaster']:
postprocess_params = {
'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path,
}
}, {
'PaddingTableImage': None
}, {
'ToCHWImage': None
}, {
'KeepKeys': {
'keep_keys': ['image']
else:
postprocess_params = {
'name': 'TableMasterLabelDecode',
"character_dict_path": args.table_char_dict_path,
'box_shape': 'pad'
}
}]
postprocess_params = {
'name': 'TableLabelDecode',
"character_dict_path": args.table_char_dict_path,
}
self.preprocess_op = create_operators(pre_process_list)
self.postprocess_op = build_post_process(postprocess_params)
......@@ -88,27 +108,17 @@ class TableStructurer(object):
preds['structure_probs'] = outputs[1]
preds['loc_preds'] = outputs[0]
post_result = self.postprocess_op(preds)
structure_str_list = post_result['structure_str_list']
res_loc = post_result['res_loc']
imgh, imgw = ori_im.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
res_loc_final.append([left, top, right, bottom])
structure_str_list = structure_str_list[0][:-1]
shape_list = np.expand_dims(data[-1], axis=0)
post_result = self.postprocess_op(preds, [shape_list])
structure_str_list = post_result['structure_batch_list'][0]
bbox_list = post_result['bbox_batch_list'][0]
structure_str_list = structure_str_list[0]
structure_str_list = [
'<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>']
elapse = time.time() - starttime
return (structure_str_list, res_loc_final), elapse
return structure_str_list, bbox_list, elapse
def main(args):
......@@ -116,21 +126,35 @@ def main(args):
table_structurer = TableStructurer(args)
count = 0
total_time = 0
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
structure_res, elapse = table_structurer(img)
logger.info("result: {}".format(structure_res))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
use_xywh = args.table_algorithm in ['TableMaster']
os.makedirs(args.output, exist_ok=True)
with open(
os.path.join(args.output, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for image_file in image_file_list:
img, flag = check_and_read_gif(image_file)
if not flag:
img = cv2.imread(image_file)
if img is None:
logger.info("error in loading image:{}".format(image_file))
continue
structure_str_list, bbox_list, elapse = table_structurer(img)
bbox_list_str = json.dumps(bbox_list.tolist())
logger.info("result: {}, {}".format(structure_str_list,
bbox_list_str))
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(image_file, bbox_list, use_xywh)
img_save_path = os.path.join(args.output,
os.path.basename(image_file))
cv2.imwrite(img_save_path, img)
logger.info("save vis result to {}".format(img_save_path))
if count > 0:
total_time += elapse
count += 1
logger.info("Predict time of {}: {}".format(image_file, elapse))
if __name__ == "__main__":
......
......@@ -25,6 +25,7 @@ def init_args():
parser.add_argument("--output", type=str, default='./output')
# params for table structure
parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str)
parser.add_argument(
"--table_char_dict_path",
......
......@@ -126,6 +126,8 @@ def export_single_model(model,
infer_shape[-1] = 100
elif arch_config["model_type"] == "table":
infer_shape = [3, 488, 488]
if arch_config["algorithm"] == "TableMaster":
infer_shape = [3, 480, 480]
model = to_static(
model,
input_spec=[
......
......@@ -36,10 +36,12 @@ from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
from ppocr.utils.visual import draw_rectangle
import tools.program as program
import cv2
@paddle.no_grad()
def main(config, device, logger, vdl_writer):
global_config = config['Global']
......@@ -53,53 +55,61 @@ def main(config, device, logger, vdl_writer):
getattr(post_process_class, 'character'))
model = build_model(config['Architecture'])
algorithm = config['Architecture']['algorithm']
use_xywh = algorithm in ['TableMaster']
load_model(config, model)
# create data ops
transforms = []
use_padding = False
for op in config['Eval']['dataset']['transforms']:
op_name = list(op)[0]
if 'Label' in op_name:
if 'Encode' in op_name:
continue
if op_name == 'KeepKeys':
op[op_name]['keep_keys'] = ['image']
if op_name == "ResizeTableImage":
use_padding = True
padding_max_len = op['ResizeTableImage']['max_len']
op[op_name]['keep_keys'] = ['image', 'shape']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
save_res_path = config['Global']['save_res_path']
os.makedirs(save_res_path, exist_ok=True)
model.eval()
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds)
res_html_code = post_result['res_html_code']
res_loc = post_result['res_loc']
img = cv2.imread(file)
imgh, imgw = img.shape[0:2]
res_loc_final = []
for rno in range(len(res_loc[0])):
x0, y0, x1, y1 = res_loc[0][rno]
left = max(int(imgw * x0), 0)
top = max(int(imgh * y0), 0)
right = min(int(imgw * x1), imgw - 1)
bottom = min(int(imgh * y1), imgh - 1)
cv2.rectangle(img, (left, top), (right, bottom), (0, 0, 255), 2)
res_loc_final.append([left, top, right, bottom])
res_loc_str = json.dumps(res_loc_final)
logger.info("result: {}, {}".format(res_html_code, res_loc_final))
logger.info("success!")
with open(
os.path.join(save_res_path, 'infer.txt'), mode='w',
encoding='utf-8') as f_w:
for file in get_image_file_list(config['Global']['infer_img']):
logger.info("infer_img: {}".format(file))
with open(file, 'rb') as f:
img = f.read()
data = {'image': img}
batch = transform(data, ops)
images = np.expand_dims(batch[0], axis=0)
shape_list = np.expand_dims(batch[1], axis=0)
images = paddle.to_tensor(images)
preds = model(images)
post_result = post_process_class(preds, [shape_list])
structure_str_list = post_result['structure_batch_list'][0]
bbox_list = post_result['bbox_batch_list'][0]
structure_str_list = structure_str_list[0]
structure_str_list = [
'<html>', '<body>', '<table>'
] + structure_str_list + ['</table>', '</body>', '</html>']
bbox_list_str = json.dumps(bbox_list.tolist())
logger.info("result: {}, {}".format(structure_str_list,
bbox_list_str))
f_w.write("result: {}, {}\n".format(structure_str_list,
bbox_list_str))
img = draw_rectangle(file, bbox_list, use_xywh)
cv2.imwrite(
os.path.join(save_res_path, os.path.basename(file)), img)
logger.info("success!")
if __name__ == '__main__':
......
......@@ -281,8 +281,11 @@ def train(config,
if cal_metric_during_train and epoch % calc_epoch_interval == 0: # only rec and cls need
batch = [item.numpy() for item in batch]
if model_type in ['table', 'kie']:
if model_type in ['kie']:
eval_class(preds, batch)
elif model_type in ['table']:
post_result = post_process_class(preds, batch)
eval_class(post_result, batch)
else:
if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
]: # for multi head loss
......@@ -463,7 +466,6 @@ def eval(model,
preds = model(batch)
else:
preds = model(images)
batch_numpy = []
for item in batch:
if isinstance(item, paddle.Tensor):
......@@ -473,9 +475,9 @@ def eval(model,
# Obtain usable results from post-processing methods
total_time += time.time() - start
# Evaluate the results of the current batch
if model_type in ['table', 'kie']:
if model_type in ['kie']:
eval_class(preds, batch_numpy)
elif model_type in ['vqa']:
elif model_type in ['table', 'vqa']:
post_result = post_process_class(preds, batch_numpy)
eval_class(post_result, batch_numpy)
else:
......@@ -577,7 +579,7 @@ def preprocess(is_train=False):
'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
'SVTR', 'ViTSTR', 'ABINet', 'DB++'
'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster'
]
if use_xpu:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册