提交 fd703c71 编写于 作者: A andyjpaddle

Merge branch 'dygraph' of https://github.com/PaddlePaddle/PaddleOCR into dygraph

...@@ -72,6 +72,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and devel ...@@ -72,6 +72,7 @@ PaddleOCR support a variety of cutting-edge algorithms related to OCR, and devel
<img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" /> <img src="https://raw.githubusercontent.com/PaddlePaddle/PaddleOCR/dygraph/doc/joinus.PNG" width = "200" height = "200" />
</div> </div>
<a name="Supported-Chinese-model-list"></a> <a name="Supported-Chinese-model-list"></a>
## PP-OCR Series Model List(Update on September 8th) ## PP-OCR Series Model List(Update on September 8th)
| Model introduction | Model name | Recommended scene | Detection model | Direction classifier | Recognition model | | Model introduction | Model name | Recommended scene | Detection model | Direction classifier | Recognition model |
......
...@@ -65,7 +65,7 @@ Loss: ...@@ -65,7 +65,7 @@ Loss:
- ["Student", "Teacher"] - ["Student", "Teacher"]
maps_name: "thrink_maps" maps_name: "thrink_maps"
weight: 1.0 weight: 1.0
act: "softmax" # act: None
model_name_pairs: ["Student", "Teacher"] model_name_pairs: ["Student", "Teacher"]
key: maps key: maps
- DistillationDBLoss: - DistillationDBLoss:
......
...@@ -60,7 +60,7 @@ Loss: ...@@ -60,7 +60,7 @@ Loss:
- ["Student", "Student2"] - ["Student", "Student2"]
maps_name: "thrink_maps" maps_name: "thrink_maps"
weight: 1.0 weight: 1.0
act: "softmax" # act: None
model_name_pairs: ["Student", "Student2"] model_name_pairs: ["Student", "Student2"]
key: maps key: maps
- DistillationDBLoss: - DistillationDBLoss:
......
Global:
use_gpu: true
epoch_num: 17
log_smooth_window: 20
print_batch_step: 100
save_model_dir: ./output/table_master/
save_epoch_step: 17
eval_batch_step: [0, 6259]
cal_metric_during_train: true
pretrained_model: null
checkpoints:
save_inference_dir: output/table_master/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
save_res_path: ./output/table_master
character_dict_path: ppocr/utils/dict/table_master_structure_dict.txt
infer_mode: false
max_text_length: 500
process_total_num: 0
process_cut_num: 0
Optimizer:
name: Adam
beta1: 0.9
beta2: 0.999
lr:
name: MultiStepDecay
learning_rate: 0.001
milestones: [12, 15]
gamma: 0.1
warmup_epoch: 0.02
regularizer:
name: L2
factor: 0.0
Architecture:
model_type: table
algorithm: TableMaster
Backbone:
name: TableResNetExtra
gcb_config:
ratio: 0.0625
headers: 1
att_scale: False
fusion_type: channel_add
layers: [False, True, True, True]
layers: [1,2,5,3]
Head:
name: TableMasterHead
hidden_size: 512
headers: 8
dropout: 0
d_ff: 2024
max_text_length: 500
Loss:
name: TableMasterLoss
ignore_index: 42 # set to len of dict + 3
PostProcess:
name: TableMasterLabelDecode
box_shape: pad
Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: False
Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: True
batch_size_per_card: 10
drop_last: True
num_workers: 8
Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/
label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl]
transforms:
- DecodeImage:
img_mode: BGR
channel_first: False
- TableMasterLabelEncode:
learn_empty_box: False
merge_no_span_structure: True
replace_empty_cell_token: True
- ResizeTableImage:
max_len: 480
resize_bboxes: True
- PaddingTableImage:
size: [480, 480]
- TableBoxEncode:
use_xywh: True
- NormalizeImage:
scale: 1./255.
mean: [0.5, 0.5, 0.5]
std: [0.5, 0.5, 0.5]
order: hwc
- ToCHWImage: null
- KeepKeys:
keep_keys: [image, structure, bboxes, bbox_masks, shape]
loader:
shuffle: False
drop_last: False
batch_size_per_card: 10
num_workers: 8
\ No newline at end of file
...@@ -4,21 +4,20 @@ Global: ...@@ -4,21 +4,20 @@ Global:
log_smooth_window: 20 log_smooth_window: 20
print_batch_step: 5 print_batch_step: 5
save_model_dir: ./output/table_mv3/ save_model_dir: ./output/table_mv3/
save_epoch_step: 3 save_epoch_step: 400
# evaluation is run every 400 iterations after the 0th iteration # evaluation is run every 400 iterations after the 0th iteration
eval_batch_step: [0, 400] eval_batch_step: [0, 400]
cal_metric_during_train: True cal_metric_during_train: True
pretrained_model: pretrained_model:
checkpoints: checkpoints:
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
infer_img: doc/table/table.jpg infer_img: ppstructure/docs/table/table.jpg
save_res_path: output/table_mv3
# for data or label process # for data or label process
character_dict_path: ppocr/utils/dict/table_structure_dict.txt character_dict_path: ppocr/utils/dict/table_structure_dict.txt
character_type: en character_type: en
max_text_length: 100 max_text_length: 800
max_elem_length: 800
max_cell_num: 500
infer_mode: False infer_mode: False
process_total_num: 0 process_total_num: 0
process_cut_num: 0 process_cut_num: 0
...@@ -44,11 +43,8 @@ Architecture: ...@@ -44,11 +43,8 @@ Architecture:
Head: Head:
name: TableAttentionHead name: TableAttentionHead
hidden_size: 256 hidden_size: 256
l2_decay: 0.00001
loc_type: 2 loc_type: 2
max_text_length: 100 max_text_length: 800
max_elem_length: 800
max_cell_num: 500
Loss: Loss:
name: TableAttentionLoss name: TableAttentionLoss
...@@ -61,28 +57,34 @@ PostProcess: ...@@ -61,28 +57,34 @@ PostProcess:
Metric: Metric:
name: TableMetric name: TableMetric
main_indicator: acc main_indicator: acc
compute_bbox_metric: false # cost many time, set False for training
Train: Train:
dataset: dataset:
name: PubTabDataSet name: PubTabDataSet
data_dir: train_data/table/pubtabnet/train/ data_dir: train_data/table/pubtabnet/train/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl label_file_list: [train_data/table/pubtabnet/PubTabNet_2.0.0_train.jsonl]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: True shuffle: True
batch_size_per_card: 32 batch_size_per_card: 32
...@@ -92,24 +94,29 @@ Train: ...@@ -92,24 +94,29 @@ Train:
Eval: Eval:
dataset: dataset:
name: PubTabDataSet name: PubTabDataSet
data_dir: train_data/table/pubtabnet/val/ data_dir: /home/zhoujun20/table/PubTabNe/pubtabnet/val/
label_file_path: train_data/table/pubtabnet/PubTabNet_2.0.0_val.jsonl label_file_list: [/home/zhoujun20/table/PubTabNe/pubtabnet/val_500.jsonl]
transforms: transforms:
- DecodeImage: # load image - DecodeImage: # load image
img_mode: BGR img_mode: BGR
channel_first: False channel_first: False
- TableLabelEncode:
learn_empty_box: False
merge_no_span_structure: False
replace_empty_cell_token: False
- TableBoxEncode:
- ResizeTableImage: - ResizeTableImage:
max_len: 488 max_len: 488
- TableLabelEncode:
- NormalizeImage: - NormalizeImage:
scale: 1./255. scale: 1./255.
mean: [0.485, 0.456, 0.406] mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225] std: [0.229, 0.224, 0.225]
order: 'hwc' order: 'hwc'
- PaddingTableImage: - PaddingTableImage:
size: [488, 488]
- ToCHWImage: - ToCHWImage:
- KeepKeys: - KeepKeys:
keep_keys: ['image', 'structure', 'bbox_list', 'sp_tokens', 'bbox_list_mask'] keep_keys: [ 'image', 'structure', 'bboxes', 'bbox_masks', 'shape' ]
loader: loader:
shuffle: False shuffle: False
drop_last: False drop_last: False
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path train_data/FUNSD/class_list.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
...@@ -3,16 +3,16 @@ Global: ...@@ -3,16 +3,16 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2/ save_model_dir: ./output/re_layoutlmv2_xfund_zh
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration # evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False cal_metric_during_train: False
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2048 seed: 2048
infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/ save_res_path: ./output/re_layoutlmv2_xfund_zh/res/
Architecture: Architecture:
model_type: vqa model_type: vqa
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutxlm_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutxlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train_v4.json
# - ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ./train_data/FUNSD/class_list.txt
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 16
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test_v4.json
# - ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlm_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
...@@ -3,16 +3,16 @@ Global: ...@@ -3,16 +3,16 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/ser_layoutlm/ save_model_dir: ./output/ser_layoutlm_xfund_zh
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration # evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False cal_metric_during_train: False
save_inference_dir: save_inference_dir:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/ save_res_path: ./output/ser_layoutlm_xfund_zh/res/
Architecture: Architecture:
model_type: vqa model_type: vqa
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 100 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlmv2_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
...@@ -3,7 +3,7 @@ Global: ...@@ -3,7 +3,7 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2/ save_model_dir: ./output/ser_layoutlmv2_xfund_zh/
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration # evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 19 ]
...@@ -12,7 +12,7 @@ Global: ...@@ -12,7 +12,7 @@ Global:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser/ save_res_path: ./output/ser_layoutlmv2_xfund_zh/res/
Architecture: Architecture:
model_type: vqa model_type: vqa
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: output/ser_layoutxlm_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: res_img_aug_with_gt
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 100
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_wildreceipt
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data//wildreceipt/image_files/Image_12/10/845be0dd6f5b04866a2042abd28d558032ef2576.jpeg
save_res_path: ./output/ser_layoutxlm_wildreceipt/res
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 51
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/wildreceipt/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/wildreceipt/
label_file_list:
- ./train_data/wildreceipt/wildreceipt_train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/wildreceipt
label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
...@@ -3,7 +3,7 @@ Global: ...@@ -3,7 +3,7 @@ Global:
epoch_num: &epoch_num 200 epoch_num: &epoch_num 200
log_smooth_window: 10 log_smooth_window: 10
print_batch_step: 10 print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm/ save_model_dir: ./output/ser_layoutxlm_xfund_zh
save_epoch_step: 2000 save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration # evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ] eval_batch_step: [ 0, 19 ]
...@@ -12,7 +12,7 @@ Global: ...@@ -12,7 +12,7 @@ Global:
use_visualdl: False use_visualdl: False
seed: 2022 seed: 2022
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser save_res_path: ./output/ser_layoutxlm_xfund_zh/res
Architecture: Architecture:
model_type: vqa model_type: vqa
......
...@@ -15,20 +15,24 @@ ...@@ -15,20 +15,24 @@
<!--- specific language governing permissions and limitations --> <!--- specific language governing permissions and limitations -->
<!--- under the License. --> <!--- under the License. -->
Running PaddleOCR text recognition model via TVM on bare metal Arm(R) Cortex(R)-M55 CPU and CMSIS-NN English | [简体中文](README_ch.md)
===============================================================
This folder contains an example of how to use TVM to run a PaddleOCR model Running PaddleOCR text recognition model on bare metal Arm(R) Cortex(R)-M55 CPU using Arm Virtual Hardware
on bare metal Cortex(R)-M55 CPU and CMSIS-NN. ======================================================================
Prerequisites This folder contains an example of how to run a PaddleOCR model on bare metal [Cortex(R)-M55 CPU](https://www.arm.com/products/silicon-ip-cpu/cortex-m/cortex-m55) using [Arm Virtual Hardware](https://www.arm.com/products/development-tools/simulation/virtual-hardware).
Running environment and prerequisites
------------- -------------
If the demo is run in the ci_cpu Docker container provided with TVM, then the following Case 1: If the demo is run in Arm Virtual Hardware Amazon Machine Image(AMI) instance hosted by [AWS](https://aws.amazon.com/marketplace/pp/prodview-urbpq7yo5va7g?sr=0-1&ref_=beagle&applicationId=AWSMPContessa)/[AWS China](https://awsmarketplace.amazonaws.cn/marketplace/pp/prodview-2y7nefntbmybu), the following software will be installed through [configure_avh.sh](./configure_avh.sh) script. It will install automatically when you run the application through [run_demo.sh](./run_demo.sh) script.
software will already be installed. You can refer to this [guide](https://arm-software.github.io/AVH/main/examples/html/MicroSpeech.html#amilaunch) to launch an Arm Virtual Hardware AMI instance.
Case 2: If the demo is run in the [ci_cpu Docker container](https://github.com/apache/tvm/blob/main/docker/Dockerfile.ci_cpu) provided with [TVM](https://github.com/apache/tvm), then the following software will already be installed.
If the demo is not run in the ci_cpu Docker container, then you will need the following: Case 3: If the demo is not run in the ci_cpu Docker container, then you will need the following:
- Software required to build and run the demo (These can all be installed by running - Software required to build and run the demo (These can all be installed by running
https://github.com/apache/tvm/blob/main/docker/install/ubuntu_install_ethosu_driver_stack.sh .) tvm/docker/install/ubuntu_install_ethosu_driver_stack.sh.)
- [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps) - [Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software](https://developer.arm.com/tools-and-software/open-source-software/arm-platforms-software/arm-ecosystem-fvps)
- [cmake 3.19.5](https://github.com/Kitware/CMake/releases/) - [cmake 3.19.5](https://github.com/Kitware/CMake/releases/)
- [GCC toolchain from Arm(R)](https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2) - [GCC toolchain from Arm(R)](https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2)
...@@ -40,19 +44,22 @@ If the demo is not run in the ci_cpu Docker container, then you will need the fo ...@@ -40,19 +44,22 @@ If the demo is not run in the ci_cpu Docker container, then you will need the fo
pip install -r ./requirements.txt pip install -r ./requirements.txt
``` ```
In case2 and case3:
You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP.
For example if you've installed these in ```/opt/arm``` , then you would do the following:
```bash
export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/bin:$PATH
```
You will also need TVM which can either be: You will also need TVM which can either be:
- Installed from TLCPack(see [TLCPack](https://tlcpack.ai/))
- Built from source (see [Install from Source](https://tvm.apache.org/docs/install/from_source.html)) - Built from source (see [Install from Source](https://tvm.apache.org/docs/install/from_source.html))
- When building from source, the following need to be set in config.cmake: - When building from source, the following need to be set in config.cmake:
- set(USE_CMSISNN ON) - set(USE_CMSISNN ON)
- set(USE_MICRO ON) - set(USE_MICRO ON)
- set(USE_LLVM ON) - set(USE_LLVM ON)
- Installed from TLCPack nightly(see [TLCPack](https://tlcpack.ai/))
You will need to update your PATH environment variable to include the path to cmake 3.19.5 and the FVP.
For example if you've installed these in ```/opt/arm``` , then you would do the following:
```bash
export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/bin:$PATH
```
Running the demo application Running the demo application
---------------------------- ----------------------------
...@@ -62,6 +69,12 @@ Type the following command to run the bare metal text recognition application ([ ...@@ -62,6 +69,12 @@ Type the following command to run the bare metal text recognition application ([
./run_demo.sh ./run_demo.sh
``` ```
If you are not able to use Arm Virtual Hardware Amazon Machine Image(AMI) instance hosted by AWS/AWS China, specify argument --enable_FVP to 1 to make the application run on local Fixed Virtual Platforms (FVPs) executables.
```bash
./run_demo.sh --enable_FVP 1
```
If the Ethos(TM)-U platform and/or CMSIS have not been installed in /opt/arm/ethosu then If the Ethos(TM)-U platform and/or CMSIS have not been installed in /opt/arm/ethosu then
the locations for these can be specified as arguments to run_demo.sh, for example: the locations for these can be specified as arguments to run_demo.sh, for example:
...@@ -70,13 +83,14 @@ the locations for these can be specified as arguments to run_demo.sh, for exampl ...@@ -70,13 +83,14 @@ the locations for these can be specified as arguments to run_demo.sh, for exampl
--ethosu_platform_path /home/tvm-user/ethosu/core_platform --ethosu_platform_path /home/tvm-user/ethosu/core_platform
``` ```
This will: With [run_demo.sh](./run_demo.sh) to run the demo application, it will:
- Set up running environment by installing the required prerequisites automatically if running in Arm Virtual Hardware Amazon AMI instance(not specify --enable_FVP to 1)
- Download a PaddleOCR text recognition model - Download a PaddleOCR text recognition model
- Use tvmc to compile the text recognition model for Cortex(R)-M55 CPU and CMSIS-NN - Use tvmc to compile the text recognition model for Cortex(R)-M55 CPU and CMSIS-NN
- Create a C header file inputs.c containing the image data as a C array - Create a C header file inputs.c containing the image data as a C array
- Create a C header file outputs.c containing a C array where the output of inference will be stored - Create a C header file outputs.c containing a C array where the output of inference will be stored
- Build the demo application - Build the demo application
- Run the demo application on a Fixed Virtual Platform (FVP) based on Arm(R) Corstone(TM)-300 software - Run the demo application on a Arm Virtual Hardware based on Arm(R) Corstone(TM)-300 software
- The application will report the text on the image and the corresponding score. - The application will report the text on the image and the corresponding score.
Using your own image Using your own image
...@@ -92,9 +106,11 @@ python3 ./convert_image.py path/to/image ...@@ -92,9 +106,11 @@ python3 ./convert_image.py path/to/image
Model description Model description
----------------- -----------------
In this demo, the model we use is an English recognition model based on [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md). PP-OCRv3 is the third version of the PP-OCR series model released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR). This series of models has the following features: The example is built on [PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md) English recognition model released by [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR). Since Arm(R) Cortex(R)-M55 CPU does not support rnn operator, we delete the unsupported operator based on the PP-OCRv3 text recognition model to obtain the current 2.7M English recognition model.
PP-OCRv3 is the third version of the PP-OCR series model. This series of models has the following features:
- PP-OCRv3: ultra-lightweight OCR system: detection (3.6M) + direction classifier (1.4M) + recognition (12M) = 17.0M - PP-OCRv3: ultra-lightweight OCR system: detection (3.6M) + direction classifier (1.4M) + recognition (12M) = 17.0M
- Support more than 80 kinds of multi-language recognition models, including English, Chinese, French, German, Arabic, Korean, Japanese and so on. For details - Support more than 80 kinds of multi-language recognition models, including English, Chinese, French, German, Arabic, Korean, Japanese and so on. For details
- Support vertical text recognition, and long text recognition - Support vertical text recognition, and long text recognition
The text recognition model in PP-OCRv3 supports more than 80 languages. In the process of model development, since Arm(R) Cortex(R)-M55 CPU does not support rnn operator, we delete the unsupported operator based on the PP-OCRv3 text recognition model to obtain the current model.
\ No newline at end of file
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
<!--- KIND, either express or implied. See the License for the --> <!--- KIND, either express or implied. See the License for the -->
<!--- specific language governing permissions and limitations --> <!--- specific language governing permissions and limitations -->
<!--- under the License. --> <!--- under the License. -->
[English](README.md) | 简体中文
通过TVM在 Arm(R) Cortex(R)-M55 CPU 上运行 PaddleOCR文 本能识别模型 通过TVM在 Arm(R) Cortex(R)-M55 CPU 上运行 PaddleOCR文 本能识别模型
=============================================================== ===============================================================
...@@ -85,9 +86,9 @@ export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/ ...@@ -85,9 +86,9 @@ export PATH=/opt/arm/FVP_Corstone_SSE-300/models/Linux64_GCC-6.4:/opt/arm/cmake/
模型描述 模型描述
----------------- -----------------
在这个demo中,我们使用的模型是基于[PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md)的英文识别模型。 PP-OCRv3是[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)发布的PP-OCR系列模型的第三个版本。 该系列模型具有以下特点: 在这个demo中,我们使用的模型是基于[PP-OCRv3](https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_ch/PP-OCRv3_introduction.md)的英文识别模型。由于Arm(R) Cortex(R)-M55 CPU不支持rnn算子,我们在PP-OCRv3原始文本识别模型的基础上进行适配,最终模型大小为2.7M。
PP-OCRv3是[PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR)发布的PP-OCR系列模型的第三个版本,该系列模型具有以下特点:
- 超轻量级OCR系统:检测(3.6M)+方向分类器(1.4M)+识别(12M)=17.0M。 - 超轻量级OCR系统:检测(3.6M)+方向分类器(1.4M)+识别(12M)=17.0M。
- 支持80多种多语言识别模型,包括英文、中文、法文、德文、阿拉伯文、韩文、日文等。 - 支持80多种多语言识别模型,包括英文、中文、法文、德文、阿拉伯文、韩文、日文等。
- 支持竖排文本识别,长文本识别。 - 支持竖排文本识别,长文本识别。
PP-OCRv3 中的文本识别模型支持 80 多种语言。 在模型开发过程中,由于Arm(R) Cortex(R)-M55 CPU不支持rnn算子,我们在PP-OCRv3文本识别模型的基础上删除了不支持的算子,得到当前模型。
\ No newline at end of file
#!/bin/bash
# Copyright (c) 2022 Arm Limited and Contributors. All rights reserved.
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
set -e
set -u
set -o pipefail
# Show usage
function show_usage() {
cat <<EOF
Usage: Set up running environment by installing the required prerequisites.
-h, --help
Display this help message.
EOF
}
if [ "$#" -eq 1 ] && [ "$1" == "--help" -o "$1" == "-h" ]; then
show_usage
exit 0
elif [ "$#" -ge 1 ]; then
show_usage
exit 1
fi
echo -e "\e[36mStart setting up running environment\e[0m"
# Install CMSIS
echo -e "\e[36mStart installing CMSIS\e[0m"
CMSIS_PATH="/opt/arm/ethosu/cmsis"
mkdir -p "${CMSIS_PATH}"
CMSIS_SHA="977abe9849781a2e788b02282986480ff4e25ea6"
CMSIS_SHASUM="86c88d9341439fbb78664f11f3f25bc9fda3cd7de89359324019a4d87d169939eea85b7fdbfa6ad03aa428c6b515ef2f8cd52299ce1959a5444d4ac305f934cc"
CMSIS_URL="http://github.com/ARM-software/CMSIS_5/archive/${CMSIS_SHA}.tar.gz"
DOWNLOAD_PATH="/tmp/${CMSIS_SHA}.tar.gz"
wget ${CMSIS_URL} -O "${DOWNLOAD_PATH}"
echo "$CMSIS_SHASUM" ${DOWNLOAD_PATH} | sha512sum -c
tar -xf "${DOWNLOAD_PATH}" -C "${CMSIS_PATH}" --strip-components=1
touch "${CMSIS_PATH}"/"${CMSIS_SHA}".sha
echo -e "\e[36mCMSIS Installation SUCCESS\e[0m"
# Install Arm(R) Ethos(TM)-U NPU driver stack
echo -e "\e[36mStart installing Arm(R) Ethos(TM)-U NPU driver stack\e[0m"
git clone "https://review.mlplatform.org/ml/ethos-u/ethos-u-core-platform" /opt/arm/ethosu/core_platform
cd /opt/arm/ethosu/core_platform
git checkout tags/"21.11"
echo -e "\e[36mArm(R) Ethos(TM)-U Core Platform Installation SUCCESS\e[0m"
# Install Arm(R) GNU Toolchain
echo -e "\e[36mStart installing Arm(R) GNU Toolchain\e[0m"
mkdir -p /opt/arm/gcc-arm-none-eabi
export gcc_arm_url='https://developer.arm.com/-/media/Files/downloads/gnu-rm/10-2020q4/gcc-arm-none-eabi-10-2020-q4-major-x86_64-linux.tar.bz2?revision=ca0cbf9c-9de2-491c-ac48-898b5bbc0443&la=en&hash=68760A8AE66026BCF99F05AC017A6A50C6FD832A'
curl --retry 64 -sSL ${gcc_arm_url} | tar -C /opt/arm/gcc-arm-none-eabi --strip-components=1 -jx
export PATH=/opt/arm/gcc-arm-none-eabi/bin:$PATH
arm-none-eabi-gcc --version
arm-none-eabi-g++ --version
echo -e "\e[36mArm(R) Arm(R) GNU Toolchain Installation SUCCESS\e[0m"
# Install TVM from TLCPack
echo -e "\e[36mStart installing TVM\e[0m"
pip install tlcpack-nightly -f https://tlcpack.ai/wheels
echo -e "\e[36mTVM Installation SUCCESS\e[0m"
\ No newline at end of file
...@@ -24,18 +24,28 @@ function show_usage() { ...@@ -24,18 +24,28 @@ function show_usage() {
cat <<EOF cat <<EOF
Usage: run_demo.sh Usage: run_demo.sh
-h, --help -h, --help
Display this help message. Display this help message.
--cmsis_path CMSIS_PATH --cmsis_path CMSIS_PATH
Set path to CMSIS. Set path to CMSIS.
--ethosu_platform_path ETHOSU_PLATFORM_PATH --ethosu_platform_path ETHOSU_PLATFORM_PATH
Set path to Arm(R) Ethos(TM)-U core platform. Set path to Arm(R) Ethos(TM)-U core platform.
--fvp_path FVP_PATH --fvp_path FVP_PATH
Set path to FVP. Set path to FVP.
--cmake_path --cmake_path
Set path to cmake. Set path to cmake.
--enable_FVP
Set 1 to run application on local Fixed Virtual Platforms (FVPs) executables.
EOF EOF
} }
# Configure environment variables
FVP_enable=0
export PATH=/opt/arm/gcc-arm-none-eabi/bin:$PATH
# Install python libraries
echo -e "\e[36mInstall python libraries\e[0m"
sudo pip install -r ./requirements.txt
# Parse arguments # Parse arguments
while (( $# )); do while (( $# )); do
case "$1" in case "$1" in
...@@ -91,6 +101,18 @@ while (( $# )); do ...@@ -91,6 +101,18 @@ while (( $# )); do
exit 1 exit 1
fi fi
;; ;;
--enable_FVP)
if [ $# -gt 1 ] && [ "$2" == "1" -o "$2" == "0" ];
then
FVP_enable="$2"
shift 2
else
echo 'ERROR: --enable_FVP requires a right argument 1 or 0' >&2
show_usage >&2
exit 1
fi
;;
-*|--*) -*|--*)
echo "Error: Unknown flag: $1" >&2 echo "Error: Unknown flag: $1" >&2
...@@ -100,17 +122,27 @@ while (( $# )); do ...@@ -100,17 +122,27 @@ while (( $# )); do
esac esac
done done
# Choose running environment: cloud(default) or local environment
Platform="VHT_Corstone_SSE-300_Ethos-U55"
if [ $FVP_enable == "1" ]; then
Platform="FVP_Corstone_SSE-300_Ethos-U55"
echo -e "\e[36mRun application on local Fixed Virtual Platforms (FVPs)\e[0m"
else
if [ ! -d "/opt/arm/" ]; then
sudo ./configure_avh.sh
fi
fi
# Directories # Directories
script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" script_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )"
# Make build directory # Make build directory
rm -rf build
make cleanall make cleanall
mkdir -p build mkdir -p build
cd build cd build
# Get PaddlePaddle inference model
echo -e "\e[36mDownload PaddlePaddle inference model\e[0m"
wget https://paddleocr.bj.bcebos.com/tvm/ocr_en.tar wget https://paddleocr.bj.bcebos.com/tvm/ocr_en.tar
tar -xf ocr_en.tar tar -xf ocr_en.tar
...@@ -144,9 +176,9 @@ cd ${script_dir} ...@@ -144,9 +176,9 @@ cd ${script_dir}
echo ${script_dir} echo ${script_dir}
make make
# Run demo executable on the FVP # Run demo executable on the AVH
FVP_Corstone_SSE-300_Ethos-U55 -C cpu0.CFGDTCMSZ=15 \ $Platform -C cpu0.CFGDTCMSZ=15 \
-C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \ -C cpu0.CFGITCMSZ=15 -C mps3_board.uart0.out_file=\"-\" -C mps3_board.uart0.shutdown_tag=\"EXITTHESIM\" \
-C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \ -C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 \
-C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \ -C mps3_board.telnetterminal1.start_telnet=0 -C mps3_board.telnetterminal2.start_telnet=0 -C mps3_board.telnetterminal5.start_telnet=0 \
./build/demo ./build/demo --stat
\ No newline at end of file
...@@ -34,12 +34,13 @@ cv::Mat CrnnResizeImg(cv::Mat img, float wh_ratio, int rec_image_height) { ...@@ -34,12 +34,13 @@ cv::Mat CrnnResizeImg(cv::Mat img, float wh_ratio, int rec_image_height) {
resize_w = imgW; resize_w = imgW;
else else
resize_w = int(ceilf(imgH * ratio)); resize_w = int(ceilf(imgH * ratio));
cv::Mat resize_img;
cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f, cv::resize(img, resize_img, cv::Size(resize_w, imgH), 0.f, 0.f,
cv::INTER_LINEAR); cv::INTER_LINEAR);
cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0, cv::copyMakeBorder(resize_img, resize_img, 0, 0, 0,
int(imgW - resize_img.cols), cv::BORDER_CONSTANT, int(imgW - resize_img.cols), cv::BORDER_CONSTANT,
{127, 127, 127}); {127, 127, 127});
return resize_img;
} }
std::vector<std::string> ReadDict(std::string path) { std::vector<std::string> ReadDict(std::string path) {
......
...@@ -474,7 +474,7 @@ void system(char **argv){ ...@@ -474,7 +474,7 @@ void system(char **argv){
std::vector<double> rec_times; std::vector<double> rec_times;
RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score, RunRecModel(boxes, srcimg, rec_predictor, rec_text, rec_text_score,
charactor_dict, cls_predictor, use_direction_classify, &rec_times); charactor_dict, cls_predictor, use_direction_classify, &rec_times, rec_image_height);
//// visualization //// visualization
auto img_vis = Visualization(srcimg, boxes); auto img_vis = Visualization(srcimg, boxes);
......
...@@ -5,9 +5,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,已支持的 ...@@ -5,9 +5,10 @@ PaddleOCR将**持续新增**支持OCR领域前沿算法与模型,已支持的
- [文本检测算法](./algorithm_overview.md#11-%E6%96%87%E6%9C%AC%E6%A3%80%E6%B5%8B%E7%AE%97%E6%B3%95) - [文本检测算法](./algorithm_overview.md#11-%E6%96%87%E6%9C%AC%E6%A3%80%E6%B5%8B%E7%AE%97%E6%B3%95)
- [文本识别算法](./algorithm_overview.md#12-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95) - [文本识别算法](./algorithm_overview.md#12-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
- [端到端算法](./algorithm_overview.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95) - [端到端算法](./algorithm_overview.md#2-%E6%96%87%E6%9C%AC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
- [表格识别]](./algorithm_overview.md#3-%E8%A1%A8%E6%A0%BC%E8%AF%86%E5%88%AB%E7%AE%97%E6%B3%95)
**欢迎广大开发者合作共建,贡献更多算法,合入有奖🎁!具体可查看[社区常规赛](https://github.com/PaddlePaddle/PaddleOCR/issues/4982)。** **欢迎广大开发者合作共建,贡献更多算法,合入有奖🎁!具体可查看[社区常规赛](https://github.com/PaddlePaddle/PaddleOCR/issues/4982)。**
新增算法可参考如下教程: 新增算法可参考如下教程:
- [使用PaddleOCR架构添加新算法](./add_new_algorithm.md) - [使用PaddleOCR架构添加新算法](./add_new_algorithm.md)
\ No newline at end of file
# FCENet # FCENet
- [1. 算法简介](#1) - [1. 算法简介](#1-算法简介)
- [2. 环境配置](#2) - [2. 环境配置](#2-环境配置)
- [3. 模型训练、评估、预测](#3) - [3. 模型训练、评估、预测](#3-模型训练评估预测)
- [3.1 训练](#3-1) - [4. 推理部署](#4-推理部署)
- [3.2 评估](#3-2) - [4.1 Python推理](#41-python推理)
- [3.3 预测](#3-3) - [4.2 C++推理](#42-c推理)
- [4. 推理部署](#4) - [4.3 Serving服务化部署](#43-serving服务化部署)
- [4.1 Python推理](#4-1) - [4.4 更多推理部署](#44-更多推理部署)
- [4.2 C++推理](#4-2) - [5. FAQ](#5-faq)
- [4.3 Serving服务化部署](#4-3) - [引用](#引用)
- [4.4 更多推理部署](#4-4)
- [5. FAQ](#5)
<a name="1"></a> <a name="1"></a>
## 1. 算法简介 ## 1. 算法简介
......
# OCR算法 # OCR算法
- [1. 两阶段算法](#1-两阶段算法) - [1. 两阶段算法](#1)
- [1.1 文本检测算法](#11-文本检测算法) - [1.1 文本检测算法](#11)
- [1.2 文本识别算法](#12-文本识别算法) - [1.2 文本识别算法](#12)
- [2. 端到端算法](#2-端到端算法) - [2. 端到端算法](#2)
- [3. 表格识别算法](#3)
本文给出了PaddleOCR已支持的OCR算法列表,以及每个算法在**英文公开数据集**上的模型和指标,主要用于算法简介和算法性能对比,更多包括中文在内的其他数据集上的模型请参考[PP-OCR v2.0 系列模型下载](./models_list.md) 本文给出了PaddleOCR已支持的OCR算法列表,以及每个算法在**英文公开数据集**上的模型和指标,主要用于算法简介和算法性能对比,更多包括中文在内的其他数据集上的模型请参考[PP-OCR v2.0 系列模型下载](./models_list.md)
...@@ -96,3 +97,16 @@ ...@@ -96,3 +97,16 @@
已支持的端到端OCR算法列表(戳链接获取使用教程): 已支持的端到端OCR算法列表(戳链接获取使用教程):
- [x] [PGNet](./algorithm_e2e_pgnet.md) - [x] [PGNet](./algorithm_e2e_pgnet.md)
<a name="3"></a>
## 3. 表格识别算法
已支持的表格识别算法列表(戳链接获取使用教程):
- [x] [TableMaster](./algorithm_table_master.md)
在PubTabNet表格识别公开数据集上,算法效果如下:
|模型|骨干网络|配置文件|acc|下载链接|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# 表格识别算法-TableMASTER
- [1. 算法简介](#1-算法简介)
- [2. 环境配置](#2-环境配置)
- [3. 模型训练、评估、预测](#3-模型训练评估预测)
- [4. 推理部署](#4-推理部署)
- [4.1 Python推理](#41-python推理)
- [4.2 C++推理部署](#42-c推理部署)
- [4.3 Serving服务化部署](#43-serving服务化部署)
- [4.4 更多推理部署](#44-更多推理部署)
- [5. FAQ](#5-faq)
- [引用](#引用)
<a name="1"></a>
## 1. 算法简介
论文信息:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
在PubTabNet表格识别公开数据集上,算法复现效果如下:
|模型|骨干网络|配置文件|acc|下载链接|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[训练模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar)/[推理模型](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. 环境配置
请先参考[《运行环境准备》](./environment.md)配置PaddleOCR运行环境,参考[《项目克隆》](./clone.md)克隆项目代码。
<a name="3"></a>
## 3. 模型训练、评估、预测
上述TableMaster模型使用PubTabNet表格识别公开数据集训练得到,数据集下载可参考 [table_datasets](./dataset/table_datasets.md)
数据下载完成后,请参考[文本识别教程](./recognition.md)进行训练。PaddleOCR对代码进行了模块化,训练不同的模型只需要**更换配置文件**即可。
<a name="4"></a>
## 4. 推理部署
<a name="4-1"></a>
### 4.1 Python推理
首先将训练得到best模型,转换成inference model。以基于TableResNetExtra骨干网络,在PubTabNet数据集训练的模型为例([模型下载地址](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)),可以使用如下命令进行转换:
```shell
# 注意将pretrained_model的路径设置为本地路径。
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
**注意:**
- 如果您是在自己的数据集上训练的模型,并且调整了字典文件,请注意修改配置文件中的`character_dict_path`是否为所正确的字典文件。
转换成功后,在目录下有三个文件:
```
./inference/table_master/
├── inference.pdiparams # 识别inference模型的参数文件
├── inference.pdiparams.info # 识别inference模型的参数信息,可忽略
└── inference.pdmodel # 识别inference模型的program文件
```
执行如下命令进行模型推理:
```shell
cd ppstructure/
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
# 预测文件夹下所有图像时,可修改image_dir为文件夹,如 --image_dir='docs/table'。
```
执行命令后,上面图像的预测结果(结构信息和表格中每个单元格的坐标)会打印到屏幕上,同时会保存单元格坐标的可视化结果。示例如下:
结果如下:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
**注意**
- TableMaster在推理时比较慢,建议使用GPU进行使用。
<a name="4-2"></a>
### 4.2 C++推理部署
由于C++预处理后处理还未支持TableMaster,所以暂未支持
<a name="4-3"></a>
### 4.3 Serving服务化部署
暂不支持
<a name="4-4"></a>
### 4.4 更多推理部署
暂不支持
<a name="5"></a>
## 5. FAQ
## 引用
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
...@@ -7,7 +7,8 @@ ...@@ -7,7 +7,8 @@
- [1. 文本检测模型推理](#1-文本检测模型推理) - [1. 文本检测模型推理](#1-文本检测模型推理)
- [2. 文本识别模型推理](#2-文本识别模型推理) - [2. 文本识别模型推理](#2-文本识别模型推理)
- [2.1 超轻量中文识别模型推理](#21-超轻量中文识别模型推理) - [2.1 超轻量中文识别模型推理](#21-超轻量中文识别模型推理)
- [2.2 多语言模型的推理](#22-多语言模型的推理) - [2.2 英文识别模型推理](#22-英文识别模型推理)
- [2.3 多语言模型的推理](#23-多语言模型的推理)
- [3. 方向分类模型推理](#3-方向分类模型推理) - [3. 方向分类模型推理](#3-方向分类模型推理)
- [4. 文本检测、方向分类和文字识别串联推理](#4-文本检测方向分类和文字识别串联推理) - [4. 文本检测、方向分类和文字识别串联推理](#4-文本检测方向分类和文字识别串联推理)
...@@ -78,9 +79,29 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg" ...@@ -78,9 +79,29 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/ch/word_4.jpg"
Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9956803321838379) Predicts of ./doc/imgs_words/ch/word_4.jpg:('实力活力', 0.9956803321838379)
``` ```
<a name="英文识别模型推理"></a>
### 2.2 英文识别模型推理
英文识别模型推理,可以执行如下命令, 注意修改字典路径:
```
# 下载英文数字识别模型:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar
tar xf en_PP-OCRv3_det_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_det_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
执行命令后,上图的预测结果为:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="多语言模型的推理"></a> <a name="多语言模型的推理"></a>
### 2.2 多语言模型的推理 ### 2.3 多语言模型的推理
如果您需要预测的是其他语言模型,可以在[此链接](./models_list.md#%E5%A4%9A%E8%AF%AD%E8%A8%80%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B)中找到对应语言的inference模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别: 如果您需要预测的是其他语言模型,可以在[此链接](./models_list.md#%E5%A4%9A%E8%AF%AD%E8%A8%80%E8%AF%86%E5%88%AB%E6%A8%A1%E5%9E%8B)中找到对应语言的inference模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果,需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别:
``` ```
......
...@@ -6,5 +6,6 @@ PaddleOCR will add cutting-edge OCR algorithms and models continuously. Check ou ...@@ -6,5 +6,6 @@ PaddleOCR will add cutting-edge OCR algorithms and models continuously. Check ou
- [text detection algorithms](./algorithm_overview_en.md#11) - [text detection algorithms](./algorithm_overview_en.md#11)
- [text recognition algorithms](./algorithm_overview_en.md#12) - [text recognition algorithms](./algorithm_overview_en.md#12)
- [end-to-end algorithms](./algorithm_overview_en.md#2) - [end-to-end algorithms](./algorithm_overview_en.md#2)
- [table recognition algorithms](./algorithm_overview_en.md#3)
Developers are welcome to contribute more algorithms! Please refer to [add new algorithm](./add_new_algorithm_en.md) guideline. Developers are welcome to contribute more algorithms! Please refer to [add new algorithm](./add_new_algorithm_en.md) guideline.
\ No newline at end of file
# OCR Algorithms # OCR Algorithms
- [1. Two-stage Algorithms](#1) - [1. Two-stage Algorithms](#1)
* [1.1 Text Detection Algorithms](#11) - [1.1 Text Detection Algorithms](#11)
* [1.2 Text Recognition Algorithms](#12) - [1.2 Text Recognition Algorithms](#12)
- [2. End-to-end Algorithms](#2) - [2. End-to-end Algorithms](#2)
- [3. Table Recognition Algorithms](#3)
This tutorial lists the OCR algorithms supported by PaddleOCR, as well as the models and metrics of each algorithm on **English public datasets**. It is mainly used for algorithm introduction and algorithm performance comparison. For more models on other datasets including Chinese, please refer to [PP-OCR v2.0 models list](./models_list_en.md). This tutorial lists the OCR algorithms supported by PaddleOCR, as well as the models and metrics of each algorithm on **English public datasets**. It is mainly used for algorithm introduction and algorithm performance comparison. For more models on other datasets including Chinese, please refer to [PP-OCR v2.0 models list](./models_list_en.md).
...@@ -95,3 +96,15 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r ...@@ -95,3 +96,15 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r
Supported end-to-end algorithms (Click the link to get the tutorial): Supported end-to-end algorithms (Click the link to get the tutorial):
- [x] [PGNet](./algorithm_e2e_pgnet_en.md) - [x] [PGNet](./algorithm_e2e_pgnet_en.md)
<a name="3"></a>
## 3. Table Recognition Algorithms
Supported table recognition algorithms (Click the link to get the tutorial):
- [x] [TableMaster](./algorithm_table_master_en.md)
On the PubTabNet dataset, the algorithm result is as follows:
|Model|Backbone|Config|Acc|Download link|
|---|---|---|---|---|
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar) / [inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
# Table Recognition Algorithm-TableMASTER
- [1. Introduction](#1-introduction)
- [2. Environment](#2-environment)
- [3. Model Training / Evaluation / Prediction](#3-model-training--evaluation--prediction)
- [4. Inference and Deployment](#4-inference-and-deployment)
- [4.1 Python Inference](#41-python-inference)
- [4.2 C++ Inference](#42-c-inference)
- [4.3 Serving](#43-serving)
- [4.4 More](#44-more)
- [5. FAQ](#5-faq)
- [Citation](#citation)
<a name="1"></a>
## 1. Introduction
Paper:
> [TableMaster: PINGAN-VCGROUP’S SOLUTION FOR ICDAR 2021 COMPETITION ON SCIENTIFIC LITERATURE PARSING TASK B: TABLE RECOGNITION TO HTML](https://arxiv.org/pdf/2105.01848.pdf)
> Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong
> 2021
On the PubTabNet table recognition public data set, the algorithm reproduction acc is as follows:
|Model|Backbone|Cnnfig|Acc|Download link|
| --- | --- | --- | --- | --- |
|TableMaster|TableResNetExtra|[configs/table/table_master.yml](../../configs/table/table_master.yml)|77.47%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_infer.tar)|
<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
<a name="3"></a>
## 3. Model Training / Evaluation / Prediction
The above TableMaster model is trained using the PubTabNet table recognition public dataset. For the download of the dataset, please refer to [table_datasets](./dataset/table_datasets_en.md).
After the data download is complete, please refer to [Text Recognition Training Tutorial](./recognition_en.md) for training. PaddleOCR has modularized the code structure, so that you only need to **replace the configuration file** to train different models.
<a name="4"></a>
## 4. Inference and Deployment
<a name="4-1"></a>
### 4.1 Python Inference
First, convert the model saved in the TableMaster table recognition training process into an inference model. Taking the model based on the TableResNetExtra backbone network and trained on the PubTabNet dataset as example ([model download link](https://paddleocr.bj.bcebos.com/contribution/table_master.tar)), you can use the following command to convert:
```shell
python3 tools/export_model.py -c configs/table/table_master.yml -o Global.pretrained_model=output/table_master/best_accuracy Global.save_inference_dir=./inference/table_master
```
**Note: **
- If you trained the model on your own dataset and adjusted the dictionary file, please pay attention to whether the `character_dict_path` in the modified configuration file is the correct dictionary file
Execute the following command for model inference:
```shell
cd ppstructure/
# When predicting all images in a folder, you can modify image_dir to a folder, such as --image_dir='docs/table'.
python3.7 table/predict_structure.py --table_model_dir=../output/table_master/table_structure_tablemaster_infer/ --table_algorithm=TableMaster --table_char_dict_path=../ppocr/utils/dict/table_master_structure_dict.txt --table_max_len=480 --image_dir=docs/table/table.jpg
```
After executing the command, the prediction results of the above image (structural information and the coordinates of each cell in the table) are printed to the screen, and the visualization of the cell coordinates is also saved. An example is as follows:
result:
```shell
[2022/06/16 13:06:54] ppocr INFO: result: ['<html>', '<body>', '<table>', '<thead>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</thead>', '<tbody>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '<tr>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '<td></td>', '</tr>', '</tbody>', '</table>', '</body>', '</html>'], [[72.17591094970703, 10.759100914001465, 60.29658508300781, 16.6805362701416], [161.85562133789062, 10.884308815002441, 14.9495210647583, 16.727018356323242], [277.79876708984375, 29.54340362548828, 31.490320205688477, 18.143272399902344],
...
[336.11724853515625, 280.3601989746094, 39.456939697265625, 18.121286392211914]]
[2022/06/16 13:06:54] ppocr INFO: save vis result to ./output/table.jpg
[2022/06/16 13:06:54] ppocr INFO: Predict time of docs/table/table.jpg: 17.36806297302246
```
**Note**:
- TableMaster is relatively slow during inference, and it is recommended to use GPU for use.
<a name="4-2"></a>
### 4.2 C++ Inference
Since the post-processing is not written in CPP, the TableMaster does not support CPP inference.
<a name="4-3"></a>
### 4.3 Serving
Not supported
<a name="4-4"></a>
### 4.4 More
Not supported
<a name="5"></a>
## 5. FAQ
## Citation
```bibtex
@article{ye2021pingan,
title={PingAn-VCGroup's Solution for ICDAR 2021 Competition on Scientific Literature Parsing Task B: Table Recognition to HTML},
author={Ye, Jiaquan and Qi, Xianbiao and He, Yelin and Chen, Yihao and Gu, Dengyi and Gao, Peng and Xiao, Rong},
journal={arXiv preprint arXiv:2105.01848},
year={2021}
}
```
...@@ -8,7 +8,8 @@ This article introduces the use of the Python inference engine for the PP-OCR mo ...@@ -8,7 +8,8 @@ This article introduces the use of the Python inference engine for the PP-OCR mo
- [Text Detection Model Inference](#text-detection-model-inference) - [Text Detection Model Inference](#text-detection-model-inference)
- [Text Recognition Model Inference](#text-recognition-model-inference) - [Text Recognition Model Inference](#text-recognition-model-inference)
- [1. Lightweight Chinese Recognition Model Inference](#1-lightweight-chinese-recognition-model-inference) - [1. Lightweight Chinese Recognition Model Inference](#1-lightweight-chinese-recognition-model-inference)
- [2. Multilingual Model Inference](#2-multilingual-model-inference) - [2. English Recognition Model Inference](#2-english-recognition-model-inference)
- [3. Multilingual Model Inference](#3-multilingual-model-inference)
- [Angle Classification Model Inference](#angle-classification-model-inference) - [Angle Classification Model Inference](#angle-classification-model-inference)
- [Text Detection Angle Classification and Recognition Inference Concatenation](#text-detection-angle-classification-and-recognition-inference-concatenation) - [Text Detection Angle Classification and Recognition Inference Concatenation](#text-detection-angle-classification-and-recognition-inference-concatenation)
...@@ -76,10 +77,31 @@ After executing the command, the prediction results (recognized text and score) ...@@ -76,10 +77,31 @@ After executing the command, the prediction results (recognized text and score)
```bash ```bash
Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.988671) Predicts of ./doc/imgs_words_en/word_10.png:('PAIN', 0.988671)
``` ```
<a name="2-english-recognition-model-inference"></a>
### 2. English Recognition Model Inference
<a name="MULTILINGUAL_MODEL_INFERENCE"></a> For English recognition model inference, you can execute the following commands,you need to specify the dictionary path used by `--rec_char_dict_path`:
### 2. Multilingual Model Inference ```
# download en model:
wget https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_det_infer.tar
tar xf en_PP-OCRv3_det_infer.tar
python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/en/word_1.png" --rec_model_dir="./en_PP-OCRv3_det_infer/" --rec_char_dict_path="ppocr/utils/en_dict.txt"
```
![](../imgs_words/en/word_1.png)
After executing the command, the prediction result of the above figure is:
```
Predicts of ./doc/imgs_words/en/word_1.png: ('JOINT', 0.998160719871521)
```
<a name="3-multilingual-model-inference"></a>
### 3. Multilingual Model Inference
If you need to predict [other language models](./models_list_en.md#Multilingual), when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results, If you need to predict [other language models](./models_list_en.md#Multilingual), when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results,
You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition: You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition:
......
...@@ -37,7 +37,7 @@ from .label_ops import * ...@@ -37,7 +37,7 @@ from .label_ops import *
from .east_process import * from .east_process import *
from .sast_process import * from .sast_process import *
from .pg_process import * from .pg_process import *
from .gen_table_mask import * from .table_ops import *
from .vqa import * from .vqa import *
......
...@@ -562,171 +562,210 @@ class SRNLabelEncode(BaseRecLabelEncode): ...@@ -562,171 +562,210 @@ class SRNLabelEncode(BaseRecLabelEncode):
return idx return idx
class TableLabelEncode(object): class TableLabelEncode(AttnLabelEncode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
def __init__(self, def __init__(self,
max_text_length, max_text_length,
max_elem_length,
max_cell_num,
character_dict_path, character_dict_path,
span_weight=1.0, replace_empty_cell_token=False,
merge_no_span_structure=False,
learn_empty_box=False,
point_num=2,
**kwargs): **kwargs):
self.max_text_length = max_text_length self.max_text_len = max_text_length
self.max_elem_length = max_elem_length self.lower = False
self.max_cell_num = max_cell_num self.learn_empty_box = learn_empty_box
list_character, list_elem = self.load_char_elem_dict( self.merge_no_span_structure = merge_no_span_structure
character_dict_path) self.replace_empty_cell_token = replace_empty_cell_token
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem) dict_character = []
self.dict_character = {}
for i, char in enumerate(list_character):
self.dict_character[char] = i
self.dict_elem = {}
for i, elem in enumerate(list_elem):
self.dict_elem[elem] = i
self.span_weight = span_weight
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin: with open(character_dict_path, "rb") as fin:
lines = fin.readlines() lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\r\n").split("\t") for line in lines:
character_num = int(substr[0]) line = line.decode('utf-8').strip("\n").strip("\r\n")
elem_num = int(substr[1]) dict_character.append(line)
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def get_span_idx_list(self): dict_character = self.add_special_char(dict_character)
span_idx_list = [] self.dict = {}
for elem in self.dict_elem: for i, char in enumerate(dict_character):
if 'span' in elem: self.dict[char] = i
span_idx_list.append(self.dict_elem[elem]) self.idx2char = {v: k for k, v in self.dict.items()}
return span_idx_list
self.character = dict_character
self.point_num = point_num
self.pad_idx = self.dict[self.beg_str]
self.start_idx = self.dict[self.beg_str]
self.end_idx = self.dict[self.end_str]
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
self.empty_bbox_token_dict = {
"[]": '<eb></eb>',
"[' ']": '<eb1></eb1>',
"['<b>', ' ', '</b>']": '<eb2></eb2>',
"['\\u2028', '\\u2028']": '<eb3></eb3>',
"['<sup>', ' ', '</sup>']": '<eb4></eb4>',
"['<b>', '</b>']": '<eb5></eb5>',
"['<i>', ' ', '</i>']": '<eb6></eb6>',
"['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
"['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
"['<i>', '</i>']": '<eb9></eb9>',
"['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
'<eb10></eb10>',
}
@property
def _max_text_len(self):
return self.max_text_len + 2
def __call__(self, data): def __call__(self, data):
cells = data['cells'] cells = data['cells']
structure = data['structure']['tokens'] structure = data['structure']
structure = self.encode(structure, 'elem') if self.merge_no_span_structure:
structure = self._merge_no_span_structure(structure)
if self.replace_empty_cell_token:
structure = self._replace_empty_cell_token(structure, cells)
# remove empty token and add " " to span token
new_structure = []
for token in structure:
if token != '':
if 'span' in token and token[0] != ' ':
token = ' ' + token
new_structure.append(token)
# encode structure
structure = self.encode(new_structure)
if structure is None: if structure is None:
return None return None
elem_num = len(structure)
structure = [0] + structure + [len(self.dict_elem) - 1] structure = [self.start_idx] + structure + [self.end_idx
structure = structure + [0] * (self.max_elem_length + 2 - len(structure) ] # add sos abd eos
) structure = structure + [self.pad_idx] * (self._max_text_len -
len(structure)) # pad
structure = np.array(structure) structure = np.array(structure)
data['structure'] = structure data['structure'] = structure
elem_char_idx1 = self.dict_elem['<td>']
elem_char_idx2 = self.dict_elem['<td'] if len(structure) > self._max_text_len:
span_idx_list = self.get_span_idx_list() return None
td_idx_list = np.logical_or(structure == elem_char_idx1,
structure == elem_char_idx2) # encode box
td_idx_list = np.where(td_idx_list)[0] bboxes = np.zeros(
(self._max_text_len, self.point_num * 2), dtype=np.float32)
structure_mask = np.ones( bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)
(self.max_elem_length + 2, 1), dtype=np.float32)
bbox_list = np.zeros((self.max_elem_length + 2, 4), dtype=np.float32) bbox_idx = 0
bbox_list_mask = np.zeros(
(self.max_elem_length + 2, 1), dtype=np.float32) for i, token in enumerate(structure):
img_height, img_width, img_ch = data['image'].shape if self.idx2char[token] in self.td_token:
if len(span_idx_list) > 0: if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
span_weight = len(td_idx_list) * 1.0 / len(span_idx_list) 'tokens']) > 0:
span_weight = min(max(span_weight, 1.0), self.span_weight) bbox = cells[bbox_idx]['bbox'].copy()
for cno in range(len(cells)): bbox = np.array(bbox, dtype=np.float32).reshape(-1)
if 'bbox' in cells[cno]: bboxes[i] = bbox
bbox = cells[cno]['bbox'].copy() bbox_masks[i] = 1.0
bbox[0] = bbox[0] * 1.0 / img_width if self.learn_empty_box:
bbox[1] = bbox[1] * 1.0 / img_height bbox_masks[i] = 1.0
bbox[2] = bbox[2] * 1.0 / img_width bbox_idx += 1
bbox[3] = bbox[3] * 1.0 / img_height data['bboxes'] = bboxes
td_idx = td_idx_list[cno] data['bbox_masks'] = bbox_masks
bbox_list[td_idx] = bbox
bbox_list_mask[td_idx] = 1.0
cand_span_idx = td_idx + 1
if cand_span_idx < (self.max_elem_length + 2):
if structure[cand_span_idx] in span_idx_list:
structure_mask[cand_span_idx] = span_weight
data['bbox_list'] = bbox_list
data['bbox_list_mask'] = bbox_list_mask
data['structure_mask'] = structure_mask
char_beg_idx = self.get_beg_end_flag_idx('beg', 'char')
char_end_idx = self.get_beg_end_flag_idx('end', 'char')
elem_beg_idx = self.get_beg_end_flag_idx('beg', 'elem')
elem_end_idx = self.get_beg_end_flag_idx('end', 'elem')
data['sp_tokens'] = np.array([
char_beg_idx, char_end_idx, elem_beg_idx, elem_end_idx,
elem_char_idx1, elem_char_idx2, self.max_text_length,
self.max_elem_length, self.max_cell_num, elem_num
])
return data return data
def encode(self, text, char_or_elem): def _merge_no_span_structure(self, structure):
"""convert text-label into text-index.
""" """
if char_or_elem == "char": This code is refer from:
max_len = self.max_text_length https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
current_dict = self.dict_character """
else: new_structure = []
max_len = self.max_elem_length i = 0
current_dict = self.dict_elem while i < len(structure):
if len(text) > max_len: token = structure[i]
return None if token == '<td>':
if len(text) == 0: token = '<td></td>'
if char_or_elem == "char": i += 1
return [self.dict_character['space']] new_structure.append(token)
else: i += 1
return None return new_structure
text_list = []
for char in text: def _replace_empty_cell_token(self, token_list, cells):
if char not in current_dict: """
return None This fun code is refer from:
text_list.append(current_dict[char]) https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
if len(text_list) == 0: """
if char_or_elem == "char":
return [self.dict_character['space']] bbox_idx = 0
add_empty_bbox_token_list = []
for token in token_list:
if token in ['<td></td>', '<td', '<td>']:
if 'bbox' not in cells[bbox_idx].keys():
content = str(cells[bbox_idx]['tokens'])
token = self.empty_bbox_token_dict[content]
add_empty_bbox_token_list.append(token)
bbox_idx += 1
else: else:
return None add_empty_bbox_token_list.append(token)
return text_list return add_empty_bbox_token_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem): class TableMasterLabelEncode(TableLabelEncode):
if char_or_elem == "char": """ Convert between text-label and text-index """
if beg_or_end == "beg":
idx = np.array(self.dict_character[self.beg_str]) def __init__(self,
elif beg_or_end == "end": max_text_length,
idx = np.array(self.dict_character[self.end_str]) character_dict_path,
else: replace_empty_cell_token=False,
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \ merge_no_span_structure=False,
% beg_or_end learn_empty_box=False,
elif char_or_elem == "elem": point_num=2,
if beg_or_end == "beg": **kwargs):
idx = np.array(self.dict_elem[self.beg_str]) super(TableMasterLabelEncode, self).__init__(
elif beg_or_end == "end": max_text_length, character_dict_path, replace_empty_cell_token,
idx = np.array(self.dict_elem[self.end_str]) merge_no_span_structure, learn_empty_box, point_num, **kwargs)
else: self.pad_idx = self.dict[self.pad_str]
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \ self.unknown_idx = self.dict[self.unknown_str]
% beg_or_end
else: @property
assert False, "Unsupport type %s in char_or_elem" \ def _max_text_len(self):
% char_or_elem return self.max_text_len
return idx
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
class TableBoxEncode(object):
def __init__(self, use_xywh=False, **kwargs):
self.use_xywh = use_xywh
def __call__(self, data):
img_height, img_width = data['image'].shape[:2]
bboxes = data['bboxes']
if self.use_xywh and bboxes.shape[1] == 4:
bboxes = self.xyxy2xywh(bboxes)
bboxes[:, 0::2] /= img_width
bboxes[:, 1::2] /= img_height
data['bboxes'] = bboxes
return data
def xyxy2xywh(self, bboxes):
"""
Convert coord (x1,y1,x2,y2) to (x,y,w,h).
where (x1,y1) is top-left, (x2,y2) is bottom-right.
(x,y) is bbox center and (w,h) is width and height.
:param bboxes: (x1, y1, x2, y2)
:return:
"""
new_bboxes = np.empty_like(bboxes)
new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
return new_bboxes
class SARLabelEncode(BaseRecLabelEncode): class SARLabelEncode(BaseRecLabelEncode):
...@@ -830,6 +869,7 @@ class VQATokenLabelEncode(object): ...@@ -830,6 +869,7 @@ class VQATokenLabelEncode(object):
contains_re=False, contains_re=False,
add_special_ids=False, add_special_ids=False,
algorithm='LayoutXLM', algorithm='LayoutXLM',
use_textline_bbox_info=True,
infer_mode=False, infer_mode=False,
ocr_engine=None, ocr_engine=None,
**kwargs): **kwargs):
...@@ -858,11 +898,51 @@ class VQATokenLabelEncode(object): ...@@ -858,11 +898,51 @@ class VQATokenLabelEncode(object):
self.add_special_ids = add_special_ids self.add_special_ids = add_special_ids
self.infer_mode = infer_mode self.infer_mode = infer_mode
self.ocr_engine = ocr_engine self.ocr_engine = ocr_engine
self.use_textline_bbox_info = use_textline_bbox_info
def split_bbox(self, bbox, text, tokenizer):
words = text.split()
token_bboxes = []
curr_word_idx = 0
x1, y1, x2, y2 = bbox
unit_w = (x2 - x1) / len(text)
for idx, word in enumerate(words):
curr_w = len(word) * unit_w
word_bbox = [x1, y1, x1 + curr_w, y2]
token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
x1 += (len(word) + 1) * unit_w
return token_bboxes
def filter_empty_contents(self, ocr_info):
"""
find out the empty texts and remove the links
"""
new_ocr_info = []
empty_index = []
for idx, info in enumerate(ocr_info):
if len(info["transcription"]) > 0:
new_ocr_info.append(copy.deepcopy(info))
else:
empty_index.append(info["id"])
for idx, info in enumerate(new_ocr_info):
new_link = []
for link in info["linking"]:
if link[0] in empty_index or link[1] in empty_index:
continue
new_link.append(link)
new_ocr_info[idx]["linking"] = new_link
return new_ocr_info
def __call__(self, data): def __call__(self, data):
# load bbox and label info # load bbox and label info
ocr_info = self._load_ocr_info(data) ocr_info = self._load_ocr_info(data)
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
ocr_info = self.filter_empty_contents(ocr_info)
height, width, _ = data['image'].shape height, width, _ = data['image'].shape
words_list = [] words_list = []
...@@ -874,8 +954,6 @@ class VQATokenLabelEncode(object): ...@@ -874,8 +954,6 @@ class VQATokenLabelEncode(object):
entities = [] entities = []
# for re
train_re = self.contains_re and not self.infer_mode
if train_re: if train_re:
relations = [] relations = []
id2label = {} id2label = {}
...@@ -885,18 +963,19 @@ class VQATokenLabelEncode(object): ...@@ -885,18 +963,19 @@ class VQATokenLabelEncode(object):
data['ocr_info'] = copy.deepcopy(ocr_info) data['ocr_info'] = copy.deepcopy(ocr_info)
for info in ocr_info: for info in ocr_info:
text = info["transcription"]
if len(text) <= 0:
continue
if train_re: if train_re:
# for re # for re
if len(info["transcription"]) == 0: if len(text) == 0:
empty_entity.add(info["id"]) empty_entity.add(info["id"])
continue continue
id2label[info["id"]] = info["label"] id2label[info["id"]] = info["label"]
relations.extend([tuple(sorted(l)) for l in info["linking"]]) relations.extend([tuple(sorted(l)) for l in info["linking"]])
# smooth_box # smooth_box
info["bbox"] = self.trans_poly_to_bbox(info["points"]) info["bbox"] = self.trans_poly_to_bbox(info["points"])
bbox = self._smooth_box(info["bbox"], height, width)
text = info["transcription"]
encode_res = self.tokenizer.encode( encode_res = self.tokenizer.encode(
text, pad_to_max_seq_len=False, return_attention_mask=True) text, pad_to_max_seq_len=False, return_attention_mask=True)
...@@ -907,6 +986,19 @@ class VQATokenLabelEncode(object): ...@@ -907,6 +986,19 @@ class VQATokenLabelEncode(object):
-1] -1]
encode_res["attention_mask"] = encode_res["attention_mask"][1: encode_res["attention_mask"] = encode_res["attention_mask"][1:
-1] -1]
if self.use_textline_bbox_info:
bbox = [info["bbox"]] * len(encode_res["input_ids"])
else:
bbox = self.split_bbox(info["bbox"], info["transcription"],
self.tokenizer)
if len(bbox) <= 0:
continue
bbox = self._smooth_box(bbox, height, width)
if self.add_special_ids:
bbox.insert(0, [0, 0, 0, 0])
bbox.append([0, 0, 0, 0])
# parse label # parse label
if not self.infer_mode: if not self.infer_mode:
label = info['label'] label = info['label']
...@@ -931,7 +1023,7 @@ class VQATokenLabelEncode(object): ...@@ -931,7 +1023,7 @@ class VQATokenLabelEncode(object):
}) })
input_ids_list.extend(encode_res["input_ids"]) input_ids_list.extend(encode_res["input_ids"])
token_type_ids_list.extend(encode_res["token_type_ids"]) token_type_ids_list.extend(encode_res["token_type_ids"])
bbox_list.extend([bbox] * len(encode_res["input_ids"])) bbox_list.extend(bbox)
words_list.append(text) words_list.append(text)
segment_offset_id.append(len(input_ids_list)) segment_offset_id.append(len(input_ids_list))
if not self.infer_mode: if not self.infer_mode:
...@@ -980,12 +1072,14 @@ class VQATokenLabelEncode(object): ...@@ -980,12 +1072,14 @@ class VQATokenLabelEncode(object):
info_dict = json.loads(info) info_dict = json.loads(info)
return info_dict return info_dict
def _smooth_box(self, bbox, height, width): def _smooth_box(self, bboxes, height, width):
bbox[0] = int(bbox[0] * 1000.0 / width) bboxes = np.array(bboxes)
bbox[2] = int(bbox[2] * 1000.0 / width) bboxes[:, 0] = bboxes[:, 0] * 1000 / width
bbox[1] = int(bbox[1] * 1000.0 / height) bboxes[:, 2] = bboxes[:, 2] * 1000 / width
bbox[3] = int(bbox[3] * 1000.0 / height) bboxes[:, 1] = bboxes[:, 1] * 1000 / height
return bbox bboxes[:, 3] = bboxes[:, 3] * 1000 / height
bboxes = bboxes.astype("int64").tolist()
return bboxes
def _parse_label(self, label, encode_res): def _parse_label(self, label, encode_res):
gt_label = [] gt_label = []
...@@ -1013,7 +1107,6 @@ class MultiLabelEncode(BaseRecLabelEncode): ...@@ -1013,7 +1107,6 @@ class MultiLabelEncode(BaseRecLabelEncode):
use_space_char, **kwargs) use_space_char, **kwargs)
def __call__(self, data): def __call__(self, data):
data_ctc = copy.deepcopy(data) data_ctc = copy.deepcopy(data)
data_sar = copy.deepcopy(data) data_sar = copy.deepcopy(data)
data_out = dict() data_out = dict()
......
...@@ -32,7 +32,7 @@ class GenTableMask(object): ...@@ -32,7 +32,7 @@ class GenTableMask(object):
self.shrink_h_max = 5 self.shrink_h_max = 5
self.shrink_w_max = 5 self.shrink_w_max = 5
self.mask_type = mask_type self.mask_type = mask_type
def projection(self, erosion, h, w, spilt_threshold=0): def projection(self, erosion, h, w, spilt_threshold=0):
# 水平投影 # 水平投影
projection_map = np.ones_like(erosion) projection_map = np.ones_like(erosion)
...@@ -48,10 +48,12 @@ class GenTableMask(object): ...@@ -48,10 +48,12 @@ class GenTableMask(object):
in_text = False # 是否遍历到了字符区内 in_text = False # 是否遍历到了字符区内
box_list = [] box_list = []
for i in range(len(project_val_array)): for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True in_text = True
start_idx = i start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i end_idx = i
in_text = False in_text = False
if end_idx - start_idx <= 2: if end_idx - start_idx <= 2:
...@@ -70,7 +72,8 @@ class GenTableMask(object): ...@@ -70,7 +72,8 @@ class GenTableMask(object):
box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY) box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
h, w = box_gray_img.shape h, w = box_gray_img.shape
# 灰度图片进行二值化处理 # 灰度图片进行二值化处理
ret, thresh1 = cv2.threshold(box_gray_img, 200, 255, cv2.THRESH_BINARY_INV) ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
cv2.THRESH_BINARY_INV)
# 纵向腐蚀 # 纵向腐蚀
if h < w: if h < w:
kernel = np.ones((2, 1), np.uint8) kernel = np.ones((2, 1), np.uint8)
...@@ -95,10 +98,12 @@ class GenTableMask(object): ...@@ -95,10 +98,12 @@ class GenTableMask(object):
box_list = [] box_list = []
spilt_threshold = 0 spilt_threshold = 0
for i in range(len(project_val_array)): for i in range(len(project_val_array)):
if in_text == False and project_val_array[i] > spilt_threshold: # 进入字符区了 if in_text == False and project_val_array[
i] > spilt_threshold: # 进入字符区了
in_text = True in_text = True
start_idx = i start_idx = i
elif project_val_array[i] <= spilt_threshold and in_text == True: # 进入空白区了 elif project_val_array[
i] <= spilt_threshold and in_text == True: # 进入空白区了
end_idx = i end_idx = i
in_text = False in_text = False
if end_idx - start_idx <= 2: if end_idx - start_idx <= 2:
...@@ -120,7 +125,8 @@ class GenTableMask(object): ...@@ -120,7 +125,8 @@ class GenTableMask(object):
h_end = h h_end = h
word_img = erosion[h_start:h_end + 1, :] word_img = erosion[h_start:h_end + 1, :]
word_h, word_w = word_img.shape word_h, word_w = word_img.shape
w_split_list, w_projection_map = self.projection(word_img.T, word_w, word_h) w_split_list, w_projection_map = self.projection(word_img.T,
word_w, word_h)
w_start, w_end = w_split_list[0][0], w_split_list[-1][1] w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
if h_start > 0: if h_start > 0:
h_start -= 1 h_start -= 1
...@@ -170,75 +176,54 @@ class GenTableMask(object): ...@@ -170,75 +176,54 @@ class GenTableMask(object):
for sno in range(len(split_bbox_list)): for sno in range(len(split_bbox_list)):
left, top, right, bottom = split_bbox_list[sno] left, top, right, bottom = split_bbox_list[sno]
left, top, right, bottom = self.shrink_bbox([left, top, right, bottom]) left, top, right, bottom = self.shrink_bbox(
[left, top, right, bottom])
if self.mask_type == 1: if self.mask_type == 1:
mask_img[top:bottom, left:right] = 1.0 mask_img[top:bottom, left:right] = 1.0
data['mask_img'] = mask_img data['mask_img'] = mask_img
else: else:
mask_img[top:bottom, left:right, :] = (255, 255, 255) mask_img[top:bottom, left:right, :] = (255, 255, 255)
data['image'] = mask_img data['image'] = mask_img
return data return data
class ResizeTableImage(object): class ResizeTableImage(object):
def __init__(self, max_len, **kwargs): def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
**kwargs):
super(ResizeTableImage, self).__init__() super(ResizeTableImage, self).__init__()
self.max_len = max_len self.max_len = max_len
self.resize_bboxes = resize_bboxes
self.infer_mode = infer_mode
def get_img_bbox(self, cells): def __call__(self, data):
bbox_list = [] img = data['image']
if len(cells) == 0:
return bbox_list
cell_num = len(cells)
for cno in range(cell_num):
if "bbox" in cells[cno]:
bbox = cells[cno]['bbox']
bbox_list.append(bbox)
return bbox_list
def resize_img_table(self, img, bbox_list, max_len):
height, width = img.shape[0:2] height, width = img.shape[0:2]
ratio = max_len / (max(height, width) * 1.0) ratio = self.max_len / (max(height, width) * 1.0)
resize_h = int(height * ratio) resize_h = int(height * ratio)
resize_w = int(width * ratio) resize_w = int(width * ratio)
img_new = cv2.resize(img, (resize_w, resize_h)) resize_img = cv2.resize(img, (resize_w, resize_h))
bbox_list_new = [] if self.resize_bboxes and not self.infer_mode:
for bno in range(len(bbox_list)): data['bboxes'] = data['bboxes'] * ratio
left, top, right, bottom = bbox_list[bno].copy() data['image'] = resize_img
left = int(left * ratio) data['src_img'] = img
top = int(top * ratio) data['shape'] = np.array([resize_h, resize_w, ratio, ratio])
right = int(right * ratio)
bottom = int(bottom * ratio)
bbox_list_new.append([left, top, right, bottom])
return img_new, bbox_list_new
def __call__(self, data):
img = data['image']
if 'cells' not in data:
cells = []
else:
cells = data['cells']
bbox_list = self.get_img_bbox(cells)
img_new, bbox_list_new = self.resize_img_table(img, bbox_list, self.max_len)
data['image'] = img_new
cell_num = len(cells)
bno = 0
for cno in range(cell_num):
if "bbox" in data['cells'][cno]:
data['cells'][cno]['bbox'] = bbox_list_new[bno]
bno += 1
data['max_len'] = self.max_len data['max_len'] = self.max_len
return data return data
class PaddingTableImage(object): class PaddingTableImage(object):
def __init__(self, **kwargs): def __init__(self, size, **kwargs):
super(PaddingTableImage, self).__init__() super(PaddingTableImage, self).__init__()
self.size = size
def __call__(self, data): def __call__(self, data):
img = data['image'] img = data['image']
max_len = data['max_len'] pad_h, pad_w = self.size
padding_img = np.zeros((max_len, max_len, 3), dtype=np.float32) padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
height, width = img.shape[0:2] height, width = img.shape[0:2]
padding_img[0:height, 0:width, :] = img.copy() padding_img[0:height, 0:width, :] = img.copy()
data['image'] = padding_img data['image'] = padding_img
shape = data['shape'].tolist()
shape.extend([pad_h, pad_w])
data['shape'] = np.array(shape)
return data return data
\ No newline at end of file
...@@ -16,6 +16,7 @@ import os ...@@ -16,6 +16,7 @@ import os
import random import random
from paddle.io import Dataset from paddle.io import Dataset
import json import json
from copy import deepcopy
from .imaug import transform, create_operators from .imaug import transform, create_operators
...@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset): ...@@ -29,33 +30,63 @@ class PubTabDataSet(Dataset):
dataset_config = config[mode]['dataset'] dataset_config = config[mode]['dataset']
loader_config = config[mode]['loader'] loader_config = config[mode]['loader']
label_file_path = dataset_config.pop('label_file_path') label_file_list = dataset_config.pop('label_file_list')
data_source_num = len(label_file_list)
ratio_list = dataset_config.get("ratio_list", [1.0])
if isinstance(ratio_list, (float, int)):
ratio_list = [float(ratio_list)] * int(data_source_num)
assert len(
ratio_list
) == data_source_num, "The length of ratio_list should be the same as the file_list."
self.data_dir = dataset_config['data_dir'] self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle'] self.do_shuffle = loader_config['shuffle']
self.do_hard_select = False
if 'hard_select' in loader_config:
self.do_hard_select = loader_config['hard_select']
self.hard_prob = loader_config['hard_prob']
if self.do_hard_select:
self.img_select_prob = self.load_hard_select_prob()
self.table_select_type = None
if 'table_select_type' in loader_config:
self.table_select_type = loader_config['table_select_type']
self.table_select_prob = loader_config['table_select_prob']
self.seed = seed self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_path) self.mode = mode.lower()
with open(label_file_path, "rb") as f: logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = f.readlines() self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines))) # self.check(config['Global']['max_text_length'])
if mode.lower() == "train":
if mode.lower() == "train" and self.do_shuffle:
self.shuffle_data_random() self.shuffle_data_random()
self.ops = create_operators(dataset_config['transforms'], global_config) self.ops = create_operators(dataset_config['transforms'], global_config)
ratio_list = dataset_config.get("ratio_list", [1.0])
self.need_reset = True in [x < 1 for x in ratio_list] self.need_reset = True in [x < 1 for x in ratio_list]
def get_image_info_list(self, file_list, ratio_list):
if isinstance(file_list, str):
file_list = [file_list]
data_lines = []
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
if self.mode == "train" or ratio_list[idx] < 1.0:
random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
return data_lines
def check(self, max_text_length):
data_lines = []
for line in self.data_lines:
data_line = line.decode('utf-8').strip("\n")
info = json.loads(data_line)
file_name = info['filename']
cells = info['html']['cells'].copy()
structure = info['html']['structure']['tokens'].copy()
img_path = os.path.join(self.data_dir, file_name)
if not os.path.exists(img_path):
self.logger.warning("{} does not exist!".format(img_path))
continue
if len(structure) == 0 or len(structure) > max_text_length:
continue
# data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
data_lines.append(line)
self.data_lines = data_lines
def shuffle_data_random(self): def shuffle_data_random(self):
if self.do_shuffle: if self.do_shuffle:
random.seed(self.seed) random.seed(self.seed)
...@@ -68,47 +99,35 @@ class PubTabDataSet(Dataset): ...@@ -68,47 +99,35 @@ class PubTabDataSet(Dataset):
data_line = data_line.decode('utf-8').strip("\n") data_line = data_line.decode('utf-8').strip("\n")
info = json.loads(data_line) info = json.loads(data_line)
file_name = info['filename'] file_name = info['filename']
select_flag = True cells = info['html']['cells'].copy()
if self.do_hard_select: structure = info['html']['structure']['tokens'].copy()
prob = self.img_select_prob[file_name]
if prob < random.uniform(0, 1): img_path = os.path.join(self.data_dir, file_name)
select_flag = False if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
if self.table_select_type: data = {
structure = info['html']['structure']['tokens'].copy() 'img_path': img_path,
structure_str = ''.join(structure) 'cells': cells,
table_type = "simple" 'structure': structure,
if 'colspan' in structure_str or 'rowspan' in structure_str: 'file_name': file_name
table_type = "complex" }
if table_type == "complex":
if self.table_select_prob < random.uniform(0, 1): with open(data['img_path'], 'rb') as f:
select_flag = False img = f.read()
data['image'] = img
if select_flag: outs = transform(data, self.ops)
cells = info['html']['cells'].copy() except:
structure = info['html']['structure'].copy() import traceback
img_path = os.path.join(self.data_dir, file_name) err = traceback.format_exc()
data = {
'img_path': img_path,
'cells': cells,
'structure': structure
}
if not os.path.exists(img_path):
raise Exception("{} does not exist!".format(img_path))
with open(data['img_path'], 'rb') as f:
img = f.read()
data['image'] = img
outs = transform(data, self.ops)
else:
outs = None
except Exception as e:
self.logger.error( self.logger.error(
"When parsing line {}, error happened with msg: {}".format( "When parsing line {}, error happened with msg: {}".format(
data_line, e)) data_line, err))
outs = None outs = None
if outs is None: if outs is None:
return self.__getitem__(np.random.randint(self.__len__())) rnd_idx = np.random.randint(self.__len__(
)) if self.mode == "train" else (idx + 1) % self.__len__()
return self.__getitem__(rnd_idx)
return outs return outs
def __len__(self): def __len__(self):
return len(self.data_idx_order_list) return len(self.data_lines)
...@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss ...@@ -51,7 +51,7 @@ from .combined_loss import CombinedLoss
# table loss # table loss
from .table_att_loss import TableAttentionLoss from .table_att_loss import TableAttentionLoss
from .table_master_loss import TableMasterLoss
# vqa token loss # vqa token loss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
...@@ -61,7 +61,8 @@ def build_loss(config): ...@@ -61,7 +61,8 @@ def build_loss(config):
'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss', 'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss', 'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss', 'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss' 'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
'TableMasterLoss'
] ]
config = copy.deepcopy(config) config = copy.deepcopy(config)
module_name = config.pop('name') module_name = config.pop('name')
......
...@@ -57,17 +57,24 @@ class CELoss(nn.Layer): ...@@ -57,17 +57,24 @@ class CELoss(nn.Layer):
class KLJSLoss(object): class KLJSLoss(object):
def __init__(self, mode='kl'): def __init__(self, mode='kl'):
assert mode in ['kl', 'js', 'KL', 'JS' assert mode in ['kl', 'js', 'KL', 'JS'
], "mode can only be one of ['kl', 'js', 'KL', 'JS']" ], "mode can only be one of ['kl', 'KL', 'js', 'JS']"
self.mode = mode self.mode = mode
def __call__(self, p1, p2, reduction="mean"): def __call__(self, p1, p2, reduction="mean"):
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5)) if self.mode.lower() == 'kl':
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
if self.mode.lower() == "js": loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply( loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5)) p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5 loss *= 0.5
else:
raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
if reduction == "mean": if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2]) loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None: elif reduction == "none" or reduction is None:
...@@ -95,7 +102,7 @@ class DMLLoss(nn.Layer): ...@@ -95,7 +102,7 @@ class DMLLoss(nn.Layer):
self.act = None self.act = None
self.use_log = use_log self.use_log = use_log
self.jskl_loss = KLJSLoss(mode="js") self.jskl_loss = KLJSLoss(mode="kl")
def _kldiv(self, x, target): def _kldiv(self, x, target):
eps = 1.0e-10 eps = 1.0e-10
......
...@@ -20,15 +20,21 @@ import paddle ...@@ -20,15 +20,21 @@ import paddle
from paddle import nn from paddle import nn
from paddle.nn import functional as F from paddle.nn import functional as F
class TableAttentionLoss(nn.Layer): class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, use_giou=False, giou_weight=1.0, **kwargs): def __init__(self,
structure_weight,
loc_weight,
use_giou=False,
giou_weight=1.0,
**kwargs):
super(TableAttentionLoss, self).__init__() super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none') self.loss_func = nn.CrossEntropyLoss(weight=None, reduction='none')
self.structure_weight = structure_weight self.structure_weight = structure_weight
self.loc_weight = loc_weight self.loc_weight = loc_weight
self.use_giou = use_giou self.use_giou = use_giou
self.giou_weight = giou_weight self.giou_weight = giou_weight
def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'): def giou_loss(self, preds, bbox, eps=1e-7, reduction='mean'):
''' '''
:param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,] :param preds:[[x1,y1,x2,y2], [x1,y1,x2,y2],,,]
...@@ -47,9 +53,10 @@ class TableAttentionLoss(nn.Layer): ...@@ -47,9 +53,10 @@ class TableAttentionLoss(nn.Layer):
inters = iw * ih inters = iw * ih
# union # union
uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (preds[:, 3] - preds[:, 1] + 1e-3 uni = (preds[:, 2] - preds[:, 0] + 1e-3) * (
) + (bbox[:, 2] - bbox[:, 0] + 1e-3) * ( preds[:, 3] - preds[:, 1] + 1e-3) + (bbox[:, 2] - bbox[:, 0] + 1e-3
bbox[:, 3] - bbox[:, 1] + 1e-3) - inters + eps ) * (bbox[:, 3] - bbox[:, 1] +
1e-3) - inters + eps
# ious # ious
ious = inters / uni ious = inters / uni
...@@ -79,30 +86,34 @@ class TableAttentionLoss(nn.Layer): ...@@ -79,30 +86,34 @@ class TableAttentionLoss(nn.Layer):
structure_probs = predicts['structure_probs'] structure_probs = predicts['structure_probs']
structure_targets = batch[1].astype("int64") structure_targets = batch[1].astype("int64")
structure_targets = structure_targets[:, 1:] structure_targets = structure_targets[:, 1:]
if len(batch) == 6: structure_probs = paddle.reshape(structure_probs,
structure_mask = batch[5].astype("int64") [-1, structure_probs.shape[-1]])
structure_mask = structure_mask[:, 1:]
structure_mask = paddle.reshape(structure_mask, [-1])
structure_probs = paddle.reshape(structure_probs, [-1, structure_probs.shape[-1]])
structure_targets = paddle.reshape(structure_targets, [-1]) structure_targets = paddle.reshape(structure_targets, [-1])
structure_loss = self.loss_func(structure_probs, structure_targets) structure_loss = self.loss_func(structure_probs, structure_targets)
if len(batch) == 6:
structure_loss = structure_loss * structure_mask
# structure_loss = paddle.sum(structure_loss) * self.structure_weight
structure_loss = paddle.mean(structure_loss) * self.structure_weight structure_loss = paddle.mean(structure_loss) * self.structure_weight
loc_preds = predicts['loc_preds'] loc_preds = predicts['loc_preds']
loc_targets = batch[2].astype("float32") loc_targets = batch[2].astype("float32")
loc_targets_mask = batch[4].astype("float32") loc_targets_mask = batch[3].astype("float32")
loc_targets = loc_targets[:, 1:, :] loc_targets = loc_targets[:, 1:, :]
loc_targets_mask = loc_targets_mask[:, 1:, :] loc_targets_mask = loc_targets_mask[:, 1:, :]
loc_loss = F.mse_loss(loc_preds * loc_targets_mask, loc_targets) * self.loc_weight loc_loss = F.mse_loss(loc_preds * loc_targets_mask,
loc_targets) * self.loc_weight
if self.use_giou: if self.use_giou:
loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask, loc_targets) * self.giou_weight loc_loss_giou = self.giou_loss(loc_preds * loc_targets_mask,
loc_targets) * self.giou_weight
total_loss = structure_loss + loc_loss + loc_loss_giou total_loss = structure_loss + loc_loss + loc_loss_giou
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss, "loc_loss_giou":loc_loss_giou} return {
'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss,
"loc_loss_giou": loc_loss_giou
}
else: else:
total_loss = structure_loss + loc_loss total_loss = structure_loss + loc_loss
return {'loss':total_loss, "structure_loss":structure_loss, "loc_loss":loc_loss} return {
\ No newline at end of file 'loss': total_loss,
"structure_loss": structure_loss,
"loc_loss": loc_loss
}
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/tree/master/mmocr/models/textrecog/losses
"""
import paddle
from paddle import nn
class TableMasterLoss(nn.Layer):
def __init__(self, ignore_index=-1):
super(TableMasterLoss, self).__init__()
self.structure_loss = nn.CrossEntropyLoss(
ignore_index=ignore_index, reduction='mean')
self.box_loss = nn.L1Loss(reduction='sum')
self.eps = 1e-12
def forward(self, predicts, batch):
# structure_loss
structure_probs = predicts['structure_probs']
structure_targets = batch[1]
structure_targets = structure_targets[:, 1:]
structure_probs = structure_probs.reshape(
[-1, structure_probs.shape[-1]])
structure_targets = structure_targets.reshape([-1])
structure_loss = self.structure_loss(structure_probs, structure_targets)
structure_loss = structure_loss.mean()
losses = dict(structure_loss=structure_loss)
# box loss
bboxes_preds = predicts['loc_preds']
bboxes_targets = batch[2][:, 1:, :]
bbox_masks = batch[3][:, 1:]
# mask empty-bbox or non-bbox structure token's bbox.
masked_bboxes_preds = bboxes_preds * bbox_masks
masked_bboxes_targets = bboxes_targets * bbox_masks
# horizon loss (x and width)
horizon_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 0::2],
masked_bboxes_targets[:, :, 0::2])
horizon_loss = horizon_sum_loss / (bbox_masks.sum() + self.eps)
# vertical loss (y and height)
vertical_sum_loss = self.box_loss(masked_bboxes_preds[:, :, 1::2],
masked_bboxes_targets[:, :, 1::2])
vertical_loss = vertical_sum_loss / (bbox_masks.sum() + self.eps)
horizon_loss = horizon_loss.mean()
vertical_loss = vertical_loss.mean()
all_loss = structure_loss + horizon_loss + vertical_loss
losses.update({
'loss': all_loss,
'horizon_bbox_loss': horizon_loss,
'vertical_bbox_loss': vertical_loss
})
return losses
...@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object): ...@@ -83,14 +83,10 @@ class DetectionIoUEvaluator(object):
evaluationLog = "" evaluationLog = ""
# print(len(gt))
for n in range(len(gt)): for n in range(len(gt)):
points = gt[n]['points'] points = gt[n]['points']
# transcription = gt[n]['text']
dontCare = gt[n]['ignore'] dontCare = gt[n]['ignore']
# points = Polygon(points) if not Polygon(points).is_valid:
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue continue
gtPol = points gtPol = points
...@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object): ...@@ -105,9 +101,7 @@ class DetectionIoUEvaluator(object):
for n in range(len(pred)): for n in range(len(pred)):
points = pred[n]['points'] points = pred[n]['points']
# points = Polygon(points) if not Polygon(points).is_valid:
# points = points.buffer(0)
if not Polygon(points).is_valid or not Polygon(points).is_simple:
continue continue
detPol = points detPol = points
...@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object): ...@@ -191,8 +185,6 @@ class DetectionIoUEvaluator(object):
methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \ methodHmean = 0 if methodRecall + methodPrecision == 0 else 2 * \
methodRecall * methodPrecision / ( methodRecall * methodPrecision / (
methodRecall + methodPrecision) methodRecall + methodPrecision)
# print(methodRecall, methodPrecision, methodHmean)
# sys.exit(-1)
methodMetrics = { methodMetrics = {
'precision': methodPrecision, 'precision': methodPrecision,
'recall': methodRecall, 'recall': methodRecall,
......
...@@ -12,29 +12,30 @@ ...@@ -12,29 +12,30 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import numpy as np
from ppocr.metrics.det_metric import DetMetric
class TableMetric(object): class TableStructureMetric(object):
def __init__(self, main_indicator='acc', **kwargs): def __init__(self, main_indicator='acc', eps=1e-6, **kwargs):
self.main_indicator = main_indicator self.main_indicator = main_indicator
self.eps = 1e-5 self.eps = eps
self.reset() self.reset()
def __call__(self, pred, batch, *args, **kwargs): def __call__(self, pred_label, batch=None, *args, **kwargs):
structure_probs = pred['structure_probs'].numpy() preds, labels = pred_label
structure_labels = batch[1] pred_structure_batch_list = preds['structure_batch_list']
gt_structure_batch_list = labels['structure_batch_list']
correct_num = 0 correct_num = 0
all_num = 0 all_num = 0
structure_probs = np.argmax(structure_probs, axis=2) for (pred, pred_conf), target in zip(pred_structure_batch_list,
structure_labels = structure_labels[:, 1:] gt_structure_batch_list):
batch_size = structure_probs.shape[0] pred_str = ''.join(pred)
for bno in range(batch_size): target_str = ''.join(target)
all_num += 1 if pred_str == target_str:
if (structure_probs[bno] == structure_labels[bno]).all():
correct_num += 1 correct_num += 1
all_num += 1
self.correct_num += correct_num self.correct_num += correct_num
self.all_num += all_num self.all_num += all_num
return {'acc': correct_num * 1.0 / (all_num + self.eps), }
def get_metric(self): def get_metric(self):
""" """
...@@ -49,3 +50,89 @@ class TableMetric(object): ...@@ -49,3 +50,89 @@ class TableMetric(object):
def reset(self): def reset(self):
self.correct_num = 0 self.correct_num = 0
self.all_num = 0 self.all_num = 0
self.len_acc_num = 0
self.token_nums = 0
self.anys_dict = dict()
class TableMetric(object):
def __init__(self,
main_indicator='acc',
compute_bbox_metric=False,
point_num=2,
**kwargs):
"""
@param sub_metrics: configs of sub_metric
@param main_matric: main_matric for save best_model
@param kwargs:
"""
self.structure_metric = TableStructureMetric()
self.bbox_metric = DetMetric() if compute_bbox_metric else None
self.main_indicator = main_indicator
self.point_num = point_num
self.reset()
def __call__(self, pred_label, batch=None, *args, **kwargs):
self.structure_metric(pred_label)
if self.bbox_metric is not None:
self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
def prepare_bbox_metric_input(self, pred_label):
pred_bbox_batch_list = []
gt_ignore_tags_batch_list = []
gt_bbox_batch_list = []
preds, labels = pred_label
batch_num = len(preds['bbox_batch_list'])
for batch_idx in range(batch_num):
# pred
pred_bbox_list = [
self.format_box(pred_box)
for pred_box in preds['bbox_batch_list'][batch_idx]
]
pred_bbox_batch_list.append({'points': pred_bbox_list})
# gt
gt_bbox_list = []
gt_ignore_tags_list = []
for gt_box in labels['bbox_batch_list'][batch_idx]:
gt_bbox_list.append(self.format_box(gt_box))
gt_ignore_tags_list.append(0)
gt_bbox_batch_list.append(gt_bbox_list)
gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
return [
pred_bbox_batch_list,
[0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
]
def get_metric(self):
structure_metric = self.structure_metric.get_metric()
if self.bbox_metric is None:
return structure_metric
bbox_metric = self.bbox_metric.get_metric()
if self.main_indicator == self.bbox_metric.main_indicator:
output = bbox_metric
for sub_key in structure_metric:
output["structure_metric_{}".format(
sub_key)] = structure_metric[sub_key]
else:
output = structure_metric
for sub_key in bbox_metric:
output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
return output
def reset(self):
self.structure_metric.reset()
if self.bbox_metric is not None:
self.bbox_metric.reset()
def format_box(self, box):
if self.point_num == 2:
x1, y1, x2, y2 = box
box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
elif self.point_num == 4:
x1, y1, x2, y2, x3, y3, x4, y4 = box
box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
return box
...@@ -37,23 +37,26 @@ class VQAReTokenMetric(object): ...@@ -37,23 +37,26 @@ class VQAReTokenMetric(object):
gt_relations = [] gt_relations = []
for b in range(len(self.relations_list)): for b in range(len(self.relations_list)):
rel_sent = [] rel_sent = []
for head, tail in zip(self.relations_list[b]["head"], if "head" in self.relations_list[b]:
self.relations_list[b]["tail"]): for head, tail in zip(self.relations_list[b]["head"],
rel = {} self.relations_list[b]["tail"]):
rel["head_id"] = head rel = {}
rel["head"] = (self.entities_list[b]["start"][rel["head_id"]], rel["head_id"] = head
self.entities_list[b]["end"][rel["head_id"]]) rel["head"] = (
rel["head_type"] = self.entities_list[b]["label"][rel[ self.entities_list[b]["start"][rel["head_id"]],
"head_id"]] self.entities_list[b]["end"][rel["head_id"]])
rel["head_type"] = self.entities_list[b]["label"][rel[
rel["tail_id"] = tail "head_id"]]
rel["tail"] = (self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]]) rel["tail_id"] = tail
rel["tail_type"] = self.entities_list[b]["label"][rel[ rel["tail"] = (
"tail_id"]] self.entities_list[b]["start"][rel["tail_id"]],
self.entities_list[b]["end"][rel["tail_id"]])
rel["type"] = 1 rel["tail_type"] = self.entities_list[b]["label"][rel[
rel_sent.append(rel) "tail_id"]]
rel["type"] = 1
rel_sent.append(rel)
gt_relations.append(rel_sent) gt_relations.append(rel_sent)
re_metrics = self.re_score( re_metrics = self.re_score(
self.pred_relations_list, gt_relations, mode="boundaries") self.pred_relations_list, gt_relations, mode="boundaries")
......
...@@ -22,6 +22,9 @@ def build_backbone(config, model_type): ...@@ -22,6 +22,9 @@ def build_backbone(config, model_type):
from .det_resnet_vd import ResNet_vd from .det_resnet_vd import ResNet_vd
from .det_resnet_vd_sast import ResNet_SAST from .det_resnet_vd_sast import ResNet_SAST
support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"] support_dict = ["MobileNetV3", "ResNet", "ResNet_vd", "ResNet_SAST"]
if model_type == "table":
from .table_master_resnet import TableResNetExtra
support_dict.append('TableResNetExtra')
elif model_type == "rec" or model_type == "cls": elif model_type == "rec" or model_type == "cls":
from .rec_mobilenet_v3 import MobileNetV3 from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet from .rec_resnet_vd import ResNet
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/backbones/table_resnet_extra.py
"""
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
class BasicBlock(nn.Layer):
expansion = 1
def __init__(self,
inplanes,
planes,
stride=1,
downsample=None,
gcb_config=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2D(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(planes, momentum=0.9)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2D(
planes, planes, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(planes, momentum=0.9)
self.downsample = downsample
self.stride = stride
self.gcb_config = gcb_config
if self.gcb_config is not None:
gcb_ratio = gcb_config['ratio']
gcb_headers = gcb_config['headers']
att_scale = gcb_config['att_scale']
fusion_type = gcb_config['fusion_type']
self.context_block = MultiAspectGCAttention(
inplanes=planes,
ratio=gcb_ratio,
headers=gcb_headers,
att_scale=att_scale,
fusion_type=fusion_type)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.gcb_config is not None:
out = self.context_block(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
def get_gcb_config(gcb_config, layer):
if gcb_config is None or not gcb_config['layers'][layer]:
return None
else:
return gcb_config
class TableResNetExtra(nn.Layer):
def __init__(self, layers, in_channels=3, gcb_config=None):
assert len(layers) >= 4
super(TableResNetExtra, self).__init__()
self.inplanes = 128
self.conv1 = nn.Conv2D(
in_channels,
64,
kernel_size=3,
stride=1,
padding=1,
bias_attr=False)
self.bn1 = nn.BatchNorm2D(64)
self.relu1 = nn.ReLU()
self.conv2 = nn.Conv2D(
64, 128, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn2 = nn.BatchNorm2D(128)
self.relu2 = nn.ReLU()
self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer1 = self._make_layer(
BasicBlock,
256,
layers[0],
stride=1,
gcb_config=get_gcb_config(gcb_config, 0))
self.conv3 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn3 = nn.BatchNorm2D(256)
self.relu3 = nn.ReLU()
self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer2 = self._make_layer(
BasicBlock,
256,
layers[1],
stride=1,
gcb_config=get_gcb_config(gcb_config, 1))
self.conv4 = nn.Conv2D(
256, 256, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn4 = nn.BatchNorm2D(256)
self.relu4 = nn.ReLU()
self.maxpool3 = nn.MaxPool2D(kernel_size=2, stride=2)
self.layer3 = self._make_layer(
BasicBlock,
512,
layers[2],
stride=1,
gcb_config=get_gcb_config(gcb_config, 2))
self.conv5 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn5 = nn.BatchNorm2D(512)
self.relu5 = nn.ReLU()
self.layer4 = self._make_layer(
BasicBlock,
512,
layers[3],
stride=1,
gcb_config=get_gcb_config(gcb_config, 3))
self.conv6 = nn.Conv2D(
512, 512, kernel_size=3, stride=1, padding=1, bias_attr=False)
self.bn6 = nn.BatchNorm2D(512)
self.relu6 = nn.ReLU()
self.out_channels = [256, 256, 512]
def _make_layer(self, block, planes, blocks, stride=1, gcb_config=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2D(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias_attr=False),
nn.BatchNorm2D(planes * block.expansion), )
layers = []
layers.append(
block(
self.inplanes,
planes,
stride,
downsample,
gcb_config=gcb_config))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
f = []
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
f.append(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu4(x)
f.append(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.conv5(x)
x = self.bn5(x)
x = self.relu5(x)
x = self.layer4(x)
x = self.conv6(x)
x = self.bn6(x)
x = self.relu6(x)
f.append(x)
return f
class MultiAspectGCAttention(nn.Layer):
def __init__(self,
inplanes,
ratio,
headers,
pooling_type='att',
att_scale=False,
fusion_type='channel_add'):
super(MultiAspectGCAttention, self).__init__()
assert pooling_type in ['avg', 'att']
assert fusion_type in ['channel_add', 'channel_mul', 'channel_concat']
assert inplanes % headers == 0 and inplanes >= 8 # inplanes must be divided by headers evenly
self.headers = headers
self.inplanes = inplanes
self.ratio = ratio
self.planes = int(inplanes * ratio)
self.pooling_type = pooling_type
self.fusion_type = fusion_type
self.att_scale = False
self.single_header_inplanes = int(inplanes / headers)
if pooling_type == 'att':
self.conv_mask = nn.Conv2D(
self.single_header_inplanes, 1, kernel_size=1)
self.softmax = nn.Softmax(axis=2)
else:
self.avg_pool = nn.AdaptiveAvgPool2D(1)
if fusion_type == 'channel_add':
self.channel_add_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
elif fusion_type == 'channel_concat':
self.channel_concat_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
# for concat
self.cat_conv = nn.Conv2D(
2 * self.inplanes, self.inplanes, kernel_size=1)
elif fusion_type == 'channel_mul':
self.channel_mul_conv = nn.Sequential(
nn.Conv2D(
self.inplanes, self.planes, kernel_size=1),
nn.LayerNorm([self.planes, 1, 1]),
nn.ReLU(),
nn.Conv2D(
self.planes, self.inplanes, kernel_size=1))
def spatial_pool(self, x):
batch, channel, height, width = x.shape
if self.pooling_type == 'att':
# [N*headers, C', H , W] C = headers * C'
x = x.reshape([
batch * self.headers, self.single_header_inplanes, height, width
])
input_x = x
# [N*headers, C', H * W] C = headers * C'
# input_x = input_x.view(batch, channel, height * width)
input_x = input_x.reshape([
batch * self.headers, self.single_header_inplanes,
height * width
])
# [N*headers, 1, C', H * W]
input_x = input_x.unsqueeze(1)
# [N*headers, 1, H, W]
context_mask = self.conv_mask(x)
# [N*headers, 1, H * W]
context_mask = context_mask.reshape(
[batch * self.headers, 1, height * width])
# scale variance
if self.att_scale and self.headers > 1:
context_mask = context_mask / paddle.sqrt(
self.single_header_inplanes)
# [N*headers, 1, H * W]
context_mask = self.softmax(context_mask)
# [N*headers, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N*headers, 1, C', 1] = [N*headers, 1, C', H * W] * [N*headers, 1, H * W, 1]
context = paddle.matmul(input_x, context_mask)
# [N, headers * C', 1, 1]
context = context.reshape(
[batch, self.headers * self.single_header_inplanes, 1, 1])
else:
# [N, C, 1, 1]
context = self.avg_pool(x)
return context
def forward(self, x):
# [N, C, 1, 1]
context = self.spatial_pool(x)
out = x
if self.fusion_type == 'channel_mul':
# [N, C, 1, 1]
channel_mul_term = F.sigmoid(self.channel_mul_conv(context))
out = out * channel_mul_term
elif self.fusion_type == 'channel_add':
# [N, C, 1, 1]
channel_add_term = self.channel_add_conv(context)
out = out + channel_add_term
else:
# [N, C, 1, 1]
channel_concat_term = self.channel_concat_conv(context)
# use concat
_, C1, _, _ = channel_concat_term.shape
N, C2, H, W = out.shape
out = paddle.concat(
[out, channel_concat_term.expand([-1, -1, H, W])], axis=1)
out = self.cat_conv(out)
out = F.layer_norm(out, [self.inplanes, H, W])
out = F.relu(out)
return out
...@@ -43,9 +43,11 @@ class NLPBaseModel(nn.Layer): ...@@ -43,9 +43,11 @@ class NLPBaseModel(nn.Layer):
super(NLPBaseModel, self).__init__() super(NLPBaseModel, self).__init__()
if checkpoints is not None: if checkpoints is not None:
self.model = model_class.from_pretrained(checkpoints) self.model = model_class.from_pretrained(checkpoints)
elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
self.model = model_class.from_pretrained(pretrained)
else: else:
pretrained_model_name = pretrained_model_dict[base_model_class] pretrained_model_name = pretrained_model_dict[base_model_class]
if pretrained: if pretrained is True:
base_model = base_model_class.from_pretrained( base_model = base_model_class.from_pretrained(
pretrained_model_name) pretrained_model_name)
else: else:
......
...@@ -42,12 +42,13 @@ def build_head(config): ...@@ -42,12 +42,13 @@ def build_head(config):
from .kie_sdmgr_head import SDMGRHead from .kie_sdmgr_head import SDMGRHead
from .table_att_head import TableAttentionHead from .table_att_head import TableAttentionHead
from .table_master_head import TableMasterHead
support_dict = [ support_dict = [
'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead', 'DBHead', 'PSEHead', 'FCEHead', 'EASTHead', 'SASTHead', 'CTCHead',
'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer', 'ClsHead', 'AttentionHead', 'SRNHead', 'PGHead', 'Transformer',
'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead', 'TableAttentionHead', 'SARHead', 'AsterHead', 'SDMGRHead', 'PRENHead',
'MultiHead', 'ABINetHead' 'MultiHead', 'ABINetHead', 'TableMasterHead'
] ]
#table head #table head
......
...@@ -21,6 +21,8 @@ import paddle.nn as nn ...@@ -21,6 +21,8 @@ import paddle.nn as nn
import paddle.nn.functional as F import paddle.nn.functional as F
import numpy as np import numpy as np
from .rec_att_head import AttentionGRUCell
class TableAttentionHead(nn.Layer): class TableAttentionHead(nn.Layer):
def __init__(self, def __init__(self,
...@@ -28,21 +30,19 @@ class TableAttentionHead(nn.Layer): ...@@ -28,21 +30,19 @@ class TableAttentionHead(nn.Layer):
hidden_size, hidden_size,
loc_type, loc_type,
in_max_len=488, in_max_len=488,
max_text_length=100, max_text_length=800,
max_elem_length=800, out_channels=30,
max_cell_num=500, point_num=2,
**kwargs): **kwargs):
super(TableAttentionHead, self).__init__() super(TableAttentionHead, self).__init__()
self.input_size = in_channels[-1] self.input_size = in_channels[-1]
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.elem_num = 30 self.out_channels = out_channels
self.max_text_length = max_text_length self.max_text_length = max_text_length
self.max_elem_length = max_elem_length
self.max_cell_num = max_cell_num
self.structure_attention_cell = AttentionGRUCell( self.structure_attention_cell = AttentionGRUCell(
self.input_size, hidden_size, self.elem_num, use_gru=False) self.input_size, hidden_size, self.out_channels, use_gru=False)
self.structure_generator = nn.Linear(hidden_size, self.elem_num) self.structure_generator = nn.Linear(hidden_size, self.out_channels)
self.loc_type = loc_type self.loc_type = loc_type
self.in_max_len = in_max_len self.in_max_len = in_max_len
...@@ -50,12 +50,13 @@ class TableAttentionHead(nn.Layer): ...@@ -50,12 +50,13 @@ class TableAttentionHead(nn.Layer):
self.loc_generator = nn.Linear(hidden_size, 4) self.loc_generator = nn.Linear(hidden_size, 4)
else: else:
if self.in_max_len == 640: if self.in_max_len == 640:
self.loc_fea_trans = nn.Linear(400, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
elif self.in_max_len == 800: elif self.in_max_len == 800:
self.loc_fea_trans = nn.Linear(625, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
else: else:
self.loc_fea_trans = nn.Linear(256, self.max_elem_length + 1) self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
self.loc_generator = nn.Linear(self.input_size + hidden_size, 4) self.loc_generator = nn.Linear(self.input_size + hidden_size,
point_num * 2)
def _char_to_onehot(self, input_char, onehot_dim): def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim) input_ont_hot = F.one_hot(input_char, onehot_dim)
...@@ -77,9 +78,9 @@ class TableAttentionHead(nn.Layer): ...@@ -77,9 +78,9 @@ class TableAttentionHead(nn.Layer):
output_hiddens = [] output_hiddens = []
if self.training and targets is not None: if self.training and targets is not None:
structure = targets[0] structure = targets[0]
for i in range(self.max_elem_length + 1): for i in range(self.max_text_length + 1):
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
structure[:, i], onehot_dim=self.elem_num) structure[:, i], onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots) hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
...@@ -102,11 +103,11 @@ class TableAttentionHead(nn.Layer): ...@@ -102,11 +103,11 @@ class TableAttentionHead(nn.Layer):
elem_onehots = None elem_onehots = None
outputs = None outputs = None
alpha = None alpha = None
max_elem_length = paddle.to_tensor(self.max_elem_length) max_text_length = paddle.to_tensor(self.max_text_length)
i = 0 i = 0
while i < max_elem_length + 1: while i < max_text_length + 1:
elem_onehots = self._char_to_onehot( elem_onehots = self._char_to_onehot(
temp_elem, onehot_dim=self.elem_num) temp_elem, onehot_dim=self.out_channels)
(outputs, hidden), alpha = self.structure_attention_cell( (outputs, hidden), alpha = self.structure_attention_cell(
hidden, fea, elem_onehots) hidden, fea, elem_onehots)
output_hiddens.append(paddle.unsqueeze(outputs, axis=1)) output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
...@@ -128,119 +129,3 @@ class TableAttentionHead(nn.Layer): ...@@ -128,119 +129,3 @@ class TableAttentionHead(nn.Layer):
loc_preds = self.loc_generator(loc_concat) loc_preds = self.loc_generator(loc_concat)
loc_preds = F.sigmoid(loc_preds) loc_preds = F.sigmoid(loc_preds)
return {'structure_probs': structure_probs, 'loc_preds': loc_preds} return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
class AttentionGRUCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionGRUCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
class AttentionLSTM(nn.Layer):
def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
super(AttentionLSTM, self).__init__()
self.input_size = in_channels
self.hidden_size = hidden_size
self.num_classes = out_channels
self.attention_cell = AttentionLSTMCell(
in_channels, hidden_size, out_channels, use_gru=False)
self.generator = nn.Linear(hidden_size, out_channels)
def _char_to_onehot(self, input_char, onehot_dim):
input_ont_hot = F.one_hot(input_char, onehot_dim)
return input_ont_hot
def forward(self, inputs, targets=None, batch_max_length=25):
batch_size = inputs.shape[0]
num_steps = batch_max_length
hidden = (paddle.zeros((batch_size, self.hidden_size)), paddle.zeros(
(batch_size, self.hidden_size)))
output_hiddens = []
if targets is not None:
for i in range(num_steps):
# one-hot vectors for a i-th char
char_onehots = self._char_to_onehot(
targets[:, i], onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
hidden = (hidden[1][0], hidden[1][1])
output_hiddens.append(paddle.unsqueeze(hidden[0], axis=1))
output = paddle.concat(output_hiddens, axis=1)
probs = self.generator(output)
else:
targets = paddle.zeros(shape=[batch_size], dtype="int32")
probs = None
for i in range(num_steps):
char_onehots = self._char_to_onehot(
targets, onehot_dim=self.num_classes)
hidden, alpha = self.attention_cell(hidden, inputs,
char_onehots)
probs_step = self.generator(hidden[0])
hidden = (hidden[1][0], hidden[1][1])
if probs is None:
probs = paddle.unsqueeze(probs_step, axis=1)
else:
probs = paddle.concat(
[probs, paddle.unsqueeze(
probs_step, axis=1)], axis=1)
next_input = probs_step.argmax(axis=1)
targets = next_input
return probs
class AttentionLSTMCell(nn.Layer):
def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
super(AttentionLSTMCell, self).__init__()
self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
self.h2h = nn.Linear(hidden_size, hidden_size)
self.score = nn.Linear(hidden_size, 1, bias_attr=False)
if not use_gru:
self.rnn = nn.LSTMCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
else:
self.rnn = nn.GRUCell(
input_size=input_size + num_embeddings, hidden_size=hidden_size)
self.hidden_size = hidden_size
def forward(self, prev_hidden, batch_H, char_onehots):
batch_H_proj = self.i2h(batch_H)
prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
res = paddle.add(batch_H_proj, prev_hidden_proj)
res = paddle.tanh(res)
e = self.score(res)
alpha = F.softmax(e, axis=1)
alpha = paddle.transpose(alpha, [0, 2, 1])
context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
concat_context = paddle.concat([context, char_onehots], 1)
cur_hidden = self.rnn(concat_context, prev_hidden)
return cur_hidden, alpha
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/mmocr/models/textrecog/decoders/master_decoder.py
"""
import copy
import math
import paddle
from paddle import nn
from paddle.nn import functional as F
class TableMasterHead(nn.Layer):
"""
Split to two transformer header at the last layer.
Cls_layer is used to structure token classification.
Bbox_layer is used to regress bbox coord.
"""
def __init__(self,
in_channels,
out_channels=30,
headers=8,
d_ff=2048,
dropout=0,
max_text_length=500,
point_num=2,
**kwargs):
super(TableMasterHead, self).__init__()
hidden_size = in_channels[-1]
self.layers = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 2)
self.cls_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.bbox_layer = clones(
DecoderLayer(headers, hidden_size, dropout, d_ff), 1)
self.cls_fc = nn.Linear(hidden_size, out_channels)
self.bbox_fc = nn.Sequential(
# nn.Linear(hidden_size, hidden_size),
nn.Linear(hidden_size, point_num * 2),
nn.Sigmoid())
self.norm = nn.LayerNorm(hidden_size)
self.embedding = Embeddings(d_model=hidden_size, vocab=out_channels)
self.positional_encoding = PositionalEncoding(d_model=hidden_size)
self.SOS = out_channels - 3
self.PAD = out_channels - 1
self.out_channels = out_channels
self.point_num = point_num
self.max_text_length = max_text_length
def make_mask(self, tgt):
"""
Make mask for self attention.
:param src: [b, c, h, l_src]
:param tgt: [b, l_tgt]
:return:
"""
trg_pad_mask = (tgt != self.PAD).unsqueeze(1).unsqueeze(3)
tgt_len = paddle.shape(tgt)[1]
trg_sub_mask = paddle.tril(
paddle.ones(
([tgt_len, tgt_len]), dtype=paddle.float32))
tgt_mask = paddle.logical_and(
trg_pad_mask.astype(paddle.float32), trg_sub_mask)
return tgt_mask.astype(paddle.float32)
def decode(self, input, feature, src_mask, tgt_mask):
# main process of transformer decoder.
x = self.embedding(input) # x: 1*x*512, feature: 1*3600,512
x = self.positional_encoding(x)
# origin transformer layers
for i, layer in enumerate(self.layers):
x = layer(x, feature, src_mask, tgt_mask)
# cls head
for layer in self.cls_layer:
cls_x = layer(x, feature, src_mask, tgt_mask)
cls_x = self.norm(cls_x)
# bbox head
for layer in self.bbox_layer:
bbox_x = layer(x, feature, src_mask, tgt_mask)
bbox_x = self.norm(bbox_x)
return self.cls_fc(cls_x), self.bbox_fc(bbox_x)
def greedy_forward(self, SOS, feature):
input = SOS
output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.out_channels])
bbox_output = paddle.zeros(
[input.shape[0], self.max_text_length + 1, self.point_num * 2])
max_text_length = paddle.to_tensor(self.max_text_length)
for i in range(max_text_length + 1):
target_mask = self.make_mask(input)
out_step, bbox_output_step = self.decode(input, feature, None,
target_mask)
prob = F.softmax(out_step, axis=-1)
next_word = prob.argmax(axis=2, dtype="int64")
input = paddle.concat(
[input, next_word[:, -1].unsqueeze(-1)], axis=1)
if i == self.max_text_length:
output = out_step
bbox_output = bbox_output_step
return output, bbox_output
def forward_train(self, out_enc, targets):
# x is token of label
# feat is feature after backbone before pe.
# out_enc is feature after pe.
padded_targets = targets[0]
src_mask = None
tgt_mask = self.make_mask(padded_targets[:, :-1])
output, bbox_output = self.decode(padded_targets[:, :-1], out_enc,
src_mask, tgt_mask)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward_test(self, out_enc):
batch_size = out_enc.shape[0]
SOS = paddle.zeros([batch_size, 1], dtype='int64') + self.SOS
output, bbox_output = self.greedy_forward(SOS, out_enc)
output = F.softmax(output)
return {'structure_probs': output, 'loc_preds': bbox_output}
def forward(self, feat, targets=None):
feat = feat[-1]
b, c, h, w = feat.shape
feat = feat.reshape([b, c, h * w]) # flatten 2D feature map
feat = feat.transpose((0, 2, 1))
out_enc = self.positional_encoding(feat)
if self.training:
return self.forward_train(out_enc, targets)
return self.forward_test(out_enc)
class DecoderLayer(nn.Layer):
"""
Decoder is made of self attention, srouce attention and feed forward.
"""
def __init__(self, headers, d_model, dropout, d_ff):
super(DecoderLayer, self).__init__()
self.self_attn = MultiHeadAttention(headers, d_model, dropout)
self.src_attn = MultiHeadAttention(headers, d_model, dropout)
self.feed_forward = FeedForward(d_model, d_ff, dropout)
self.sublayer = clones(SubLayerConnection(d_model, dropout), 3)
def forward(self, x, feature, src_mask, tgt_mask):
x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))
x = self.sublayer[1](
x, lambda x: self.src_attn(x, feature, feature, src_mask))
return self.sublayer[2](x, self.feed_forward)
class MultiHeadAttention(nn.Layer):
def __init__(self, headers, d_model, dropout):
super(MultiHeadAttention, self).__init__()
assert d_model % headers == 0
self.d_k = int(d_model / headers)
self.headers = headers
self.linears = clones(nn.Linear(d_model, d_model), 4)
self.attn = None
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
B = query.shape[0]
# 1) Do all the linear projections in batch from d_model => h x d_k
query, key, value = \
[l(x).reshape([B, 0, self.headers, self.d_k]).transpose([0, 2, 1, 3])
for l, x in zip(self.linears, (query, key, value))]
# 2) Apply attention on all the projected vectors in batch
x, self.attn = self_attention(
query, key, value, mask=mask, dropout=self.dropout)
x = x.transpose([0, 2, 1, 3]).reshape([B, 0, self.headers * self.d_k])
return self.linears[-1](x)
class FeedForward(nn.Layer):
def __init__(self, d_model, d_ff, dropout):
super(FeedForward, self).__init__()
self.w_1 = nn.Linear(d_model, d_ff)
self.w_2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.w_2(self.dropout(F.relu(self.w_1(x))))
class SubLayerConnection(nn.Layer):
"""
A residual connection followed by a layer norm.
Note for code simplicity the norm is first as opposed to last.
"""
def __init__(self, size, dropout):
super(SubLayerConnection, self).__init__()
self.norm = nn.LayerNorm(size)
self.dropout = nn.Dropout(dropout)
def forward(self, x, sublayer):
return x + self.dropout(sublayer(self.norm(x)))
def masked_fill(x, mask, value):
mask = mask.astype(x.dtype)
return x * paddle.logical_not(mask).astype(x.dtype) + mask * value
def self_attention(query, key, value, mask=None, dropout=None):
"""
Compute 'Scale Dot Product Attention'
"""
d_k = value.shape[-1]
score = paddle.matmul(query, key.transpose([0, 1, 3, 2]) / math.sqrt(d_k))
if mask is not None:
# score = score.masked_fill(mask == 0, -1e9) # b, h, L, L
score = masked_fill(score, mask == 0, -6.55e4) # for fp16
p_attn = F.softmax(score, axis=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return paddle.matmul(p_attn, value), p_attn
def clones(module, N):
""" Produce N identical layers """
return nn.LayerList([copy.deepcopy(module) for _ in range(N)])
class Embeddings(nn.Layer):
def __init__(self, d_model, vocab):
super(Embeddings, self).__init__()
self.lut = nn.Embedding(vocab, d_model)
self.d_model = d_model
def forward(self, *input):
x = input[0]
return self.lut(x) * math.sqrt(self.d_model)
class PositionalEncoding(nn.Layer):
""" Implement the PE function. """
def __init__(self, d_model, dropout=0., max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
# Compute the positional encodings once in log space.
pe = paddle.zeros([max_len, d_model])
position = paddle.arange(0, max_len).unsqueeze(1).astype('float32')
div_term = paddle.exp(
paddle.arange(0, d_model, 2) * -math.log(10000.0) / d_model)
pe[:, 0::2] = paddle.sin(position * div_term)
pe[:, 1::2] = paddle.cos(position * div_term)
pe = pe.unsqueeze(0)
self.register_buffer('pe', pe)
def forward(self, feat, **kwargs):
feat = feat + self.pe[:, :paddle.shape(feat)[1]] # pe 1*5000*512
return self.dropout(feat)
...@@ -343,3 +343,46 @@ class DecayLearningRate(object): ...@@ -343,3 +343,46 @@ class DecayLearningRate(object):
power=self.factor, power=self.factor,
end_lr=self.end_lr) end_lr=self.end_lr)
return learning_rate return learning_rate
class MultiStepDecay(object):
"""
Piecewise learning rate decay
Args:
step_each_epoch(int): steps each epoch
learning_rate (float): The initial learning rate. It is a python float number.
step_size (int): the interval to update.
gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
It should be less than 1.0. Default: 0.1.
last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
"""
def __init__(self,
learning_rate,
milestones,
step_each_epoch,
gamma,
warmup_epoch=0,
last_epoch=-1,
**kwargs):
super(MultiStepDecay, self).__init__()
self.milestones = [step_each_epoch * e for e in milestones]
self.learning_rate = learning_rate
self.gamma = gamma
self.last_epoch = last_epoch
self.warmup_epoch = round(warmup_epoch * step_each_epoch)
def __call__(self):
learning_rate = lr.MultiStepDecay(
learning_rate=self.learning_rate,
milestones=self.milestones,
gamma=self.gamma,
last_epoch=self.last_epoch)
if self.warmup_epoch > 0:
learning_rate = lr.LinearWarmup(
learning_rate=learning_rate,
warmup_steps=self.warmup_epoch,
start_lr=0.0,
end_lr=self.learning_rate,
last_epoch=self.last_epoch)
return learning_rate
...@@ -26,12 +26,13 @@ from .east_postprocess import EASTPostProcess ...@@ -26,12 +26,13 @@ from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess from .sast_postprocess import SASTPostProcess
from .fce_postprocess import FCEPostProcess from .fce_postprocess import FCEPostProcess
from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
DistillationCTCLabelDecode, TableLabelDecode, NRTRLabelDecode, SARLabelDecode, \ DistillationCTCLabelDecode, NRTRLabelDecode, SARLabelDecode, \
SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode SEEDLabelDecode, PRENLabelDecode, ViTSTRLabelDecode, ABINetLabelDecode
from .cls_postprocess import ClsPostProcess from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
def build_post_process(config, global_config=None): def build_post_process(config, global_config=None):
...@@ -42,7 +43,8 @@ def build_post_process(config, global_config=None): ...@@ -42,7 +43,8 @@ def build_post_process(config, global_config=None):
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode', 'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess', 'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode', 'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode' 'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode'
] ]
if config['name'] == 'PSEPostProcess': if config['name'] == 'PSEPostProcess':
......
...@@ -38,6 +38,7 @@ class DBPostProcess(object): ...@@ -38,6 +38,7 @@ class DBPostProcess(object):
unclip_ratio=2.0, unclip_ratio=2.0,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False,
**kwargs): **kwargs):
self.thresh = thresh self.thresh = thresh
self.box_thresh = box_thresh self.box_thresh = box_thresh
...@@ -45,6 +46,7 @@ class DBPostProcess(object): ...@@ -45,6 +46,7 @@ class DBPostProcess(object):
self.unclip_ratio = unclip_ratio self.unclip_ratio = unclip_ratio
self.min_size = 3 self.min_size = 3
self.score_mode = score_mode self.score_mode = score_mode
self.use_polygon = use_polygon
assert score_mode in [ assert score_mode in [
"slow", "fast" "slow", "fast"
], "Score mode must be in [slow, fast] but got: {}".format(score_mode) ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
...@@ -52,6 +54,53 @@ class DBPostProcess(object): ...@@ -52,6 +54,53 @@ class DBPostProcess(object):
self.dilation_kernel = None if not use_dilation else np.array( self.dilation_kernel = None if not use_dilation else np.array(
[[1, 1], [1, 1]]) [[1, 1], [1, 1]])
def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
'''
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
'''
bitmap = _bitmap
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:self.max_candidates]:
epsilon = 0.002 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = self.box_score_fast(pred, points.reshape(-1, 2))
if self.box_thresh > score:
continue
if points.shape[0] > 2:
box = self.unclip(points, self.unclip_ratio)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < self.min_size + 2:
continue
box = np.array(box)
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height): def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
''' '''
_bitmap: single map with shape (1, H, W), _bitmap: single map with shape (1, H, W),
...@@ -85,7 +134,7 @@ class DBPostProcess(object): ...@@ -85,7 +134,7 @@ class DBPostProcess(object):
if self.box_thresh > score: if self.box_thresh > score:
continue continue
box = self.unclip(points).reshape(-1, 1, 2) box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
box, sside = self.get_mini_boxes(box) box, sside = self.get_mini_boxes(box)
if sside < self.min_size + 2: if sside < self.min_size + 2:
continue continue
...@@ -99,8 +148,7 @@ class DBPostProcess(object): ...@@ -99,8 +148,7 @@ class DBPostProcess(object):
scores.append(score) scores.append(score)
return np.array(boxes, dtype=np.int16), scores return np.array(boxes, dtype=np.int16), scores
def unclip(self, box): def unclip(self, box, unclip_ratio):
unclip_ratio = self.unclip_ratio
poly = Polygon(box) poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset() offset = pyclipper.PyclipperOffset()
...@@ -185,8 +233,12 @@ class DBPostProcess(object): ...@@ -185,8 +233,12 @@ class DBPostProcess(object):
self.dilation_kernel) self.dilation_kernel)
else: else:
mask = segmentation[batch_index] mask = segmentation[batch_index]
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, if self.use_polygon is True:
src_w, src_h) boxes, scores = self.polygons_from_bitmap(pred[batch_index],
mask, src_w, src_h)
else:
boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
src_w, src_h)
boxes_batch.append({'points': boxes}) boxes_batch.append({'points': boxes})
return boxes_batch return boxes_batch
...@@ -202,6 +254,7 @@ class DistillationDBPostProcess(object): ...@@ -202,6 +254,7 @@ class DistillationDBPostProcess(object):
unclip_ratio=1.5, unclip_ratio=1.5,
use_dilation=False, use_dilation=False,
score_mode="fast", score_mode="fast",
use_polygon=False,
**kwargs): **kwargs):
self.model_name = model_name self.model_name = model_name
self.key = key self.key = key
...@@ -211,7 +264,8 @@ class DistillationDBPostProcess(object): ...@@ -211,7 +264,8 @@ class DistillationDBPostProcess(object):
max_candidates=max_candidates, max_candidates=max_candidates,
unclip_ratio=unclip_ratio, unclip_ratio=unclip_ratio,
use_dilation=use_dilation, use_dilation=use_dilation,
score_mode=score_mode) score_mode=score_mode,
use_polygon=use_polygon)
def __call__(self, predicts, shape_list): def __call__(self, predicts, shape_list):
results = {} results = {}
......
...@@ -58,6 +58,8 @@ class PSEPostProcess(object): ...@@ -58,6 +58,8 @@ class PSEPostProcess(object):
kernels = (pred > self.thresh).astype('float32') kernels = (pred > self.thresh).astype('float32')
text_mask = kernels[:, 0, :, :] text_mask = kernels[:, 0, :, :]
text_mask = paddle.unsqueeze(text_mask, axis=1)
kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask kernels[:, 0:, :, :] = kernels[:, 0:, :, :] * text_mask
score = score.numpy() score = score.numpy()
......
...@@ -380,146 +380,6 @@ class SRNLabelDecode(BaseRecLabelDecode): ...@@ -380,146 +380,6 @@ class SRNLabelDecode(BaseRecLabelDecode):
return idx return idx
class TableLabelDecode(object):
""" """
def __init__(self, character_dict_path, **kwargs):
list_character, list_elem = self.load_char_elem_dict(
character_dict_path)
list_character = self.add_special_char(list_character)
list_elem = self.add_special_char(list_elem)
self.dict_character = {}
self.dict_idx_character = {}
for i, char in enumerate(list_character):
self.dict_idx_character[i] = char
self.dict_character[char] = i
self.dict_elem = {}
self.dict_idx_elem = {}
for i, elem in enumerate(list_elem):
self.dict_idx_elem[i] = elem
self.dict_elem[elem] = i
def load_char_elem_dict(self, character_dict_path):
list_character = []
list_elem = []
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
substr = lines[0].decode('utf-8').strip("\n").strip("\r\n").split(
"\t")
character_num = int(substr[0])
elem_num = int(substr[1])
for cno in range(1, 1 + character_num):
character = lines[cno].decode('utf-8').strip("\n").strip("\r\n")
list_character.append(character)
for eno in range(1 + character_num, 1 + character_num + elem_num):
elem = lines[eno].decode('utf-8').strip("\n").strip("\r\n")
list_elem.append(elem)
return list_character, list_elem
def add_special_char(self, list_character):
self.beg_str = "sos"
self.end_str = "eos"
list_character = [self.beg_str] + list_character + [self.end_str]
return list_character
def __call__(self, preds):
structure_probs = preds['structure_probs']
loc_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(loc_preds, paddle.Tensor):
loc_preds = loc_preds.numpy()
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_str, structure_pos, result_score_list, result_elem_idx_list = self.decode(
structure_idx, structure_probs, 'elem')
res_html_code_list = []
res_loc_list = []
batch_num = len(structure_str)
for bno in range(batch_num):
res_loc = []
for sno in range(len(structure_str[bno])):
text = structure_str[bno][sno]
if text in ['<td>', '<td']:
pos = structure_pos[bno][sno]
res_loc.append(loc_preds[bno, pos])
res_html_code = ''.join(structure_str[bno])
res_loc = np.array(res_loc)
res_html_code_list.append(res_html_code)
res_loc_list.append(res_loc)
return {
'res_html_code': res_html_code_list,
'res_loc': res_loc_list,
'res_score_list': result_score_list,
'res_elem_idx_list': result_elem_idx_list,
'structure_str_list': structure_str
}
def decode(self, text_index, structure_probs, char_or_elem):
"""convert text-label into text-index.
"""
if char_or_elem == "char":
current_dict = self.dict_idx_character
else:
current_dict = self.dict_idx_elem
ignored_tokens = self.get_ignored_tokens('elem')
beg_idx, end_idx = ignored_tokens
result_list = []
result_pos_list = []
result_score_list = []
result_elem_idx_list = []
batch_size = len(text_index)
for batch_idx in range(batch_size):
char_list = []
elem_pos_list = []
elem_idx_list = []
score_list = []
for idx in range(len(text_index[batch_idx])):
tmp_elem_idx = int(text_index[batch_idx][idx])
if idx > 0 and tmp_elem_idx == end_idx:
break
if tmp_elem_idx in ignored_tokens:
continue
char_list.append(current_dict[tmp_elem_idx])
elem_pos_list.append(idx)
score_list.append(structure_probs[batch_idx, idx])
elem_idx_list.append(tmp_elem_idx)
result_list.append(char_list)
result_pos_list.append(elem_pos_list)
result_score_list.append(score_list)
result_elem_idx_list.append(elem_idx_list)
return result_list, result_pos_list, result_score_list, result_elem_idx_list
def get_ignored_tokens(self, char_or_elem):
beg_idx = self.get_beg_end_flag_idx("beg", char_or_elem)
end_idx = self.get_beg_end_flag_idx("end", char_or_elem)
return [beg_idx, end_idx]
def get_beg_end_flag_idx(self, beg_or_end, char_or_elem):
if char_or_elem == "char":
if beg_or_end == "beg":
idx = self.dict_character[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_character[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of char" \
% beg_or_end
elif char_or_elem == "elem":
if beg_or_end == "beg":
idx = self.dict_elem[self.beg_str]
elif beg_or_end == "end":
idx = self.dict_elem[self.end_str]
else:
assert False, "Unsupport type %s in get_beg_end_flag_idx of elem" \
% beg_or_end
else:
assert False, "Unsupport type %s in char_or_elem" \
% char_or_elem
return idx
class SARLabelDecode(BaseRecLabelDecode): class SARLabelDecode(BaseRecLabelDecode):
""" Convert between text-label and text-index """ """ Convert between text-label and text-index """
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .rec_postprocess import AttnLabelDecode
class TableLabelDecode(AttnLabelDecode):
""" """
def __init__(self, character_dict_path, **kwargs):
super(TableLabelDecode, self).__init__(character_dict_path)
self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
def __call__(self, preds, batch=None):
structure_probs = preds['structure_probs']
bbox_preds = preds['loc_preds']
if isinstance(structure_probs, paddle.Tensor):
structure_probs = structure_probs.numpy()
if isinstance(bbox_preds, paddle.Tensor):
bbox_preds = bbox_preds.numpy()
shape_list = batch[-1]
result = self.decode(structure_probs, bbox_preds, shape_list)
if len(batch) == 1: # only contains shape
return result
label_decode_result = self.decode_label(batch)
return result, label_decode_result
def decode(self, structure_probs, bbox_preds, shape_list):
"""convert text-label into text-index.
"""
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_idx = structure_probs.argmax(axis=2)
structure_probs = structure_probs.max(axis=2)
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
score_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
text = self.character[char_idx]
if text in self.td_token:
bbox = bbox_preds[batch_idx, idx]
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_list.append(text)
score_list.append(structure_probs[batch_idx, idx])
structure_batch_list.append([structure_list, np.mean(score_list)])
bbox_batch_list.append(np.array(bbox_list))
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def decode_label(self, batch):
"""convert text-label into text-index.
"""
structure_idx = batch[1]
gt_bbox_list = batch[2]
shape_list = batch[-1]
ignored_tokens = self.get_ignored_tokens()
end_idx = self.dict[self.end_str]
structure_batch_list = []
bbox_batch_list = []
batch_size = len(structure_idx)
for batch_idx in range(batch_size):
structure_list = []
bbox_list = []
for idx in range(len(structure_idx[batch_idx])):
char_idx = int(structure_idx[batch_idx][idx])
if idx > 0 and char_idx == end_idx:
break
if char_idx in ignored_tokens:
continue
structure_list.append(self.character[char_idx])
bbox = gt_bbox_list[batch_idx][idx]
if bbox.sum() != 0:
bbox = self._bbox_decode(bbox, shape_list[batch_idx])
bbox_list.append(bbox)
structure_batch_list.append(structure_list)
bbox_batch_list.append(bbox_list)
result = {
'bbox_batch_list': bbox_batch_list,
'structure_batch_list': structure_batch_list,
}
return result
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
src_h = h / ratio_h
src_w = w / ratio_w
bbox[0::2] *= src_w
bbox[1::2] *= src_h
return bbox
class TableMasterLabelDecode(TableLabelDecode):
""" """
def __init__(self, character_dict_path, box_shape='ori', **kwargs):
super(TableMasterLabelDecode, self).__init__(character_dict_path)
self.box_shape = box_shape
assert box_shape in [
'ori', 'pad'
], 'The shape used for box normalization must be ori or pad'
def add_special_char(self, dict_character):
self.beg_str = '<SOS>'
self.end_str = '<EOS>'
self.unknown_str = '<UKN>'
self.pad_str = '<PAD>'
dict_character = dict_character
dict_character = dict_character + [
self.unknown_str, self.beg_str, self.end_str, self.pad_str
]
return dict_character
def get_ignored_tokens(self):
pad_idx = self.dict[self.pad_str]
start_idx = self.dict[self.beg_str]
end_idx = self.dict[self.end_str]
unknown_idx = self.dict[self.unknown_str]
return [start_idx, end_idx, pad_idx, unknown_idx]
def _bbox_decode(self, bbox, shape):
h, w, ratio_h, ratio_w, pad_h, pad_w = shape
if self.box_shape == 'pad':
h, w = pad_h, pad_w
bbox[0::2] *= w
bbox[1::2] *= h
bbox[0::2] /= ratio_w
bbox[1::2] /= ratio_h
return bbox
<thead>
<tr>
<td></td>
</tr>
</thead>
<tbody>
<eb></eb>
</tbody>
<td
colspan="5"
>
</td>
colspan="2"
colspan="3"
<eb2></eb2>
<eb1></eb1>
rowspan="2"
colspan="4"
colspan="6"
rowspan="3"
colspan="9"
colspan="10"
colspan="7"
rowspan="4"
rowspan="5"
rowspan="9"
colspan="8"
rowspan="8"
rowspan="6"
rowspan="7"
rowspan="10"
<eb3></eb3>
<eb4></eb4>
<eb5></eb5>
<eb6></eb6>
<eb7></eb7>
<eb8></eb8>
<eb9></eb9>
<eb10></eb10>
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import cv2
import os import os
import numpy as np import numpy as np
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
...@@ -110,3 +111,16 @@ def draw_re_results(image, ...@@ -110,3 +111,16 @@ def draw_re_results(image,
img_new = Image.blend(image, img_new, 0.5) img_new = Image.blend(image, img_new, 0.5)
return np.array(img_new) return np.array(img_new)
def draw_rectangle(img_path, boxes, use_xywh=False):
img = cv2.imread(img_path)
img_show = img.copy()
for box in boxes.astype(int):
if use_xywh:
x, y, w, h = box
x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
else:
x1, y1, x2, y2 = box
cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
return img_show
\ No newline at end of file
...@@ -35,7 +35,7 @@ ...@@ -35,7 +35,7 @@
|模型名称|模型简介|推理模型大小|下载地址| |模型名称|模型简介|推理模型大小|下载地址|
| --- | --- | --- | --- | | --- | --- | --- | --- |
|en_ppocr_mobile_v2.0_table_structure|PubLayNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) | |en_ppocr_mobile_v2.0_table_structure|PubTabNet数据集训练的英文表格场景的表格结构预测|18.6M|[推理模型](https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_structure_infer.tar) / [训练模型](https://paddleocr.bj.bcebos.com/dygraph_v2.1/table/en_ppocr_mobile_v2.0_table_structure_train.tar) |
<a name="3"></a> <a name="3"></a>
## 3. VQA模型 ## 3. VQA模型
......
...@@ -25,6 +25,7 @@ def init_args(): ...@@ -25,6 +25,7 @@ def init_args():
parser.add_argument("--output", type=str, default='./output') parser.add_argument("--output", type=str, default='./output')
# params for table structure # params for table structure
parser.add_argument("--table_max_len", type=int, default=488) parser.add_argument("--table_max_len", type=int, default=488)
parser.add_argument("--table_algorithm", type=str, default='TableAttn')
parser.add_argument("--table_model_dir", type=str) parser.add_argument("--table_model_dir", type=str)
parser.add_argument( parser.add_argument(
"--table_char_dict_path", "--table_char_dict_path",
......
此差异已折叠。
...@@ -54,6 +54,6 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] ...@@ -54,6 +54,6 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params========================== ===========================train_benchmark_params==========================
batch_size:8|16 batch_size:8|16
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:2 epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096 flags:FLAGS_eager_delete_tensor_gb=0.0;FLAGS_fraction_of_gpu_memory_to_use=0.98;FLAGS_conv_workspace_size_limit=4096
...@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] ...@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params========================== ===========================train_benchmark_params==========================
batch_size:8|16 batch_size:8|16
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:2 epoch:15
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
===========================train_params=========================== ===========================train_params===========================
model_name:det_r50_db++ model_name:det_r50_db_plusplus
python:python3.7 python:python3.7
gpu_list:0|0,1 gpu_list:0|0,1
Global.use_gpu:True|True Global.use_gpu:True|True
......
...@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}] ...@@ -54,5 +54,5 @@ random_infer_input:[{float32,[3,640,640]}];[{float32,[3,960,960]}]
===========================train_benchmark_params========================== ===========================train_benchmark_params==========================
batch_size:8 batch_size:8
fp_items:fp32|fp16 fp_items:fp32|fp16
epoch:2 epoch:10
--profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile --profiler_options:batch_range=[10,20];state=GPU;tracer_option=Default;profile_path=model.profile
此差异已折叠。
===========================train_params===========================
model_name:table_master
python:python3.7
gpu_list:0|0,1
Global.use_gpu:True|True
Global.auto_cast:fp32
Global.epoch_num:lite_train_lite_infer=1|whole_train_whole_infer=17
Global.save_model_dir:./output/
Train.loader.batch_size_per_card:lite_train_lite_infer=2|whole_train_whole_infer=4
Global.pretrained_model:./pretrain_models/table_structure_tablemaster_train/best_accuracy
train_model_name:latest
train_infer_img_dir:./ppstructure/docs/table/table.jpg
null:null
##
trainer:norm_train
norm_train:tools/train.py -c test_tipc/configs/table_master/table_master.yml -o Global.print_batch_step=10
pact_train:null
fpgm_train:null
distill_train:null
null:null
null:null
##
===========================eval_params===========================
eval:null
null:null
##
===========================infer_params===========================
Global.save_inference_dir:./output/
Global.checkpoints:
norm_export:tools/export_model.py -c test_tipc/configs/table_master/table_master.yml -o
quant_export:
fpgm_export:
distill_export:null
export1:null
export2:null
##
infer_model:null
infer_export:null
infer_quant:False
inference:ppstructure/table/predict_structure.py --table_char_dict_path=./ppocr/utils/dict/table_master_structure_dict.txt --image_dir=./ppstructure/docs/table/table.jpg --output ./output/table --table_algorithm=TableMaster --table_max_len=480
--use_gpu:True|False
--enable_mkldnn:False
--cpu_threads:6
--rec_batch_num:1
--use_tensorrt:False
--precision:fp32
--table_model_dir:
--image_dir:./ppstructure/docs/table/table.jpg
null:null
--benchmark:False
null:null
===========================infer_benchmark_params==========================
random_infer_input:[{float32,[3,480,480]}]
...@@ -22,13 +22,19 @@ trainer_list=$(func_parser_value "${lines[14]}") ...@@ -22,13 +22,19 @@ trainer_list=$(func_parser_value "${lines[14]}")
if [ ${MODE} = "benchmark_train" ];then if [ ${MODE} = "benchmark_train" ];then
pip install -r requirements.txt pip install -r requirements.txt
if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then if [[ ${model_name} =~ "det_mv3_db_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" || ${model_name} =~ "det_r18_db_v2_0" ]];then
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/MobileNetV3_large_x0_5_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../ cd ./train_data/ && tar xf icdar2015.tar && cd ../
fi fi
if [[ ${model_name} =~ "det_r50_vd_east_v2_0" || ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then if [[ ${model_name} =~ "det_r50_vd_east_v2_0" ]]; then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/det_r50_vd_east_v2.0_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf det_r50_vd_east_v2.0_train.tar && cd ../
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../
fi
if [[ ${model_name} =~ "det_r50_vd_pse_v2_0" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/pretrained/ResNet50_vd_ssld_pretrained.pdparams --no-check-certificate
wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate wget -nc -P ./train_data/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/test/icdar2015.tar --no-check-certificate
cd ./train_data/ && tar xf icdar2015.tar && cd ../ cd ./train_data/ && tar xf icdar2015.tar && cd ../
...@@ -59,9 +65,13 @@ if [ ${MODE} = "lite_train_lite_infer" ];then ...@@ -59,9 +65,13 @@ if [ ${MODE} = "lite_train_lite_infer" ];then
wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate wget -nc -P ./inference/ https://paddleocr.bj.bcebos.com/dygraph_v2.0/table/en_ppocr_mobile_v2.0_table_rec_infer.tar --no-check-certificate
cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../ cd ./inference/ && tar xf en_ppocr_mobile_v2.0_table_det_infer.tar && tar xf en_ppocr_mobile_v2.0_table_rec_infer.tar && cd ../
fi fi
if [[ ${model_name} =~ "det_r50_db++" ]];then if [[ ${model_name} =~ "det_r50_db_plusplus" ]];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams --no-check-certificate wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/dygraph_v2.1/en_det/ResNet50_dcn_asf_synthtext_pretrained.pdparams --no-check-certificate
fi fi
if [ ${model_name} == "table_master" ];then
wget -nc -P ./pretrain_models/ https://paddleocr.bj.bcebos.com/ppstructure/models/tablemaster/table_structure_tablemaster_train.tar --no-check-certificate
cd ./pretrain_models/ && tar xf table_structure_tablemaster_train.tar && cd ../
fi
cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../ cd ./pretrain_models/ && tar xf det_mv3_db_v2.0_train.tar && cd ../
rm -rf ./train_data/icdar2015 rm -rf ./train_data/icdar2015
rm -rf ./train_data/ic15_data rm -rf ./train_data/ic15_data
......
...@@ -54,6 +54,7 @@ ...@@ -54,6 +54,7 @@
| NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - | | NRTR |rec_mtb_nrtr | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - | | SAR |rec_r31_sar | 识别 | 支持 | 多机多卡 <br> 混合精度 | - | - |
| PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 <br> 混合精度 | - | - | | PGNet |rec_r34_vd_none_none_ctc_v2.0 | 端到端| 支持 | 多机多卡 <br> 混合精度 | - | - |
| TableMaster |table_structure_tablemaster_train | 表格识别| 支持 | 多机多卡 <br> 混合精度 | - | - |
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册