Unverified commit 9e4ae9dc, authored by littletomatodonkey, committed by GitHub

add vqa code (#7096)

* add vqa code

* add order ocr info

* rename tb-yx order

* polish configs

* add trt offline-tuning

* fix seed and remove unused configs
Parent: 8c4e1dcc
@@ -6,11 +6,11 @@ Global:
  save_model_dir: ./output/re_layoutlmv2_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
-  seed: 2048
+  seed: 2022
  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
  save_res_path: ./output/re_layoutlmv2_xfund_zh/res/
...
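For readers unfamiliar with the option: `eval_batch_step: [start, interval]` is conventionally read as "start evaluating after `start` iterations, then evaluate every `interval` iterations". A minimal sketch of that assumed interpretation (names illustrative, not the actual trainer code):

def should_evaluate(global_step, start_step=0, interval=19):
    # assumed semantics of eval_batch_step = [start_step, interval]
    return global_step > start_step and global_step % interval == 0

assert should_evaluate(19) and not should_evaluate(20)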
Global:
  use_gpu: True
-  epoch_num: &epoch_num 200
+  epoch_num: &epoch_num 130
  log_smooth_window: 10
  print_batch_step: 10
-  save_model_dir: ./output/re_layoutxlm/
+  save_model_dir: ./output/re_layoutxlm_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
@@ -12,7 +12,7 @@ Global:
  use_visualdl: False
  seed: 2022
  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
-  save_res_path: ./output/re/
+  save_res_path: ./output/re_layoutxlm_xfund_zh/res/

Architecture:
  model_type: vqa
@@ -81,7 +81,7 @@ Train:
  loader:
    shuffle: True
    drop_last: False
-    batch_size_per_card: 8
+    batch_size_per_card: 2
    num_workers: 8
    collate_fn: ListCollator
...
@@ -6,13 +6,13 @@ Global:
  save_model_dir: ./output/ser_layoutlm_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
-  save_res_path: ./output/ser_layoutlm_xfund_zh/res/
+  save_res_path: ./output/re_layoutlm_xfund_zh/res

Architecture:
  model_type: vqa
@@ -55,6 +55,7 @@ Train:
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/train.json
+    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
...
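The newly added `ratio_list` pairs with `label_file_list`: each entry is the fraction of samples drawn from the corresponding label file (1.0 keeps everything). A rough sketch of the assumed sampling behavior:

import random

def subsample(label_lines, ratio):
    # keep round(len * ratio) randomly chosen samples from one label file
    return random.sample(label_lines, round(len(label_lines) * ratio))

assert len(subsample(list(range(100)), 1.0)) == 100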
@@ -27,6 +27,7 @@ Architecture:
Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes
+  key: "backbone_out"

Optimizer:
  name: AdamW
...
@@ -27,6 +27,7 @@ Architecture:
Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes
+  key: "backbone_out"

Optimizer:
  name: AdamW
...
Global:
  use_gpu: True
-  epoch_num: &epoch_num 200
+  epoch_num: &epoch_num 130
  log_smooth_window: 10
  print_batch_step: 10
-  save_model_dir: ./output/re_layoutxlm_funsd
+  save_model_dir: ./output/re_vi_layoutxlm_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: ./output/re_layoutxlm_funsd/res/
+  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
+  save_res_path: ./output/re/xfund_zh/with_gt

Architecture:
  model_type: vqa
@@ -21,6 +21,7 @@ Architecture:
  Backbone:
    name: LayoutXLMForRe
    pretrained: True
+    mode: vi
    checkpoints:

Loss:
@@ -50,10 +51,9 @@ Metric:

Train:
  dataset:
    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
+    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
-      - ./train_data/FUNSD/train_v4.json
-      # - ./train_data/FUNSD/train.json
+      - train_data/XFUND/zh_train/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
@@ -62,8 +62,9 @@ Train:
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
-          class_path: &class_path ./train_data/FUNSD/class_list.txt
+          class_path: &class_path train_data/XFUND/class_list_xfun.txt
          use_textline_bbox_info: &use_textline_bbox_info True
+          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
@@ -79,22 +80,20 @@ Train:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
+          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
  loader:
-    shuffle: False
+    shuffle: True
    drop_last: False
-    batch_size_per_card: 8
-    num_workers: 16
+    batch_size_per_card: 2
+    num_workers: 4
    collate_fn: ListCollator

Eval:
  dataset:
    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/testing_data/images/
+    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
-      - ./train_data/FUNSD/test_v4.json
-      # - ./train_data/FUNSD/test.json
+      - train_data/XFUND/zh_val/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
@@ -104,6 +103,7 @@ Eval:
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
+          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
@@ -119,11 +119,11 @@ Eval:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator
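The `&order_method` anchor defined in the Train pipeline and the `*order_method` alias in Eval keep both pipelines on the same ordering setting; the same pattern is used for `class_path`, `use_textline_bbox_info`, and `max_seq_len`. A small PyYAML demonstration of the mechanism:

import yaml

cfg = yaml.safe_load("""
train:
  order_method: &order_method "tb-yx"
eval:
  order_method: *order_method
""")
assert cfg["train"]["order_method"] == cfg["eval"]["order_method"] == "tb-yx"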
Global:
  use_gpu: True
  epoch_num: &epoch_num 130
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/re_vi_layoutxlm_xfund_zh_udml
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
  save_res_path: ./output/re/xfund_zh/with_gt

Architecture:
  model_type: &model_type "vqa"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: &algorithm "LayoutXLM"
      Transform:
      Backbone:
        name: LayoutXLMForRe
        pretrained: True
        mode: vi
        checkpoints:
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: *algorithm
      Transform:
      Backbone:
        name: LayoutXLMForRe
        pretrained: True
        mode: vi
        checkpoints:

Loss:
  name: CombinedLoss
  loss_config_list:
    - DistillationLossFromOutput:
        weight: 1.0
        model_name_list: ["Student", "Teacher"]
        key: loss
        reduction: mean
    - DistillationVQADistanceLoss:
        weight: 0.5
        mode: "l2"
        model_name_pairs:
          - ["Student", "Teacher"]
        key: hidden_states_5
        name: "loss_5"
    - DistillationVQADistanceLoss:
        weight: 0.5
        mode: "l2"
        model_name_pairs:
          - ["Student", "Teacher"]
        key: hidden_states_8
        name: "loss_8"

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  lr:
    learning_rate: 0.00005
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: DistillationRePostProcess
  model_name: ["Student", "Teacher"]
  key: null

Metric:
  name: DistillationMetric
  base_metric_name: VQAReTokenMetric
  main_indicator: hmean
  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
      - train_data/XFUND/zh_train/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: &class_path train_data/XFUND/class_list_xfun.txt
          use_textline_bbox_info: &use_textline_bbox_info True
          # [None, "tb-yx"]
          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 2
    num_workers: 4
    collate_fn: ListCollator

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
      - train_data/XFUND/zh_val/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator
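The UDML config above combines one ground-truth loss with two hidden-state distance losses. A sketch of how weighted sub-losses are typically folded into one total (illustrative only; the real CombinedLoss lives in ppocr/losses/combined_loss.py):

def combine(sub_losses, weights):
    # each sub-loss returns a dict of named scalar losses
    return sum(w * sum(d.values()) for d, w in zip(sub_losses, weights))

print(round(combine([{"loss_re": 1.2}, {"loss_5": 0.4}, {"loss_8": 0.6}],
                    [1.0, 0.5, 0.5]), 6))  # 1.7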
@@ -3,30 +3,38 @@ Global:
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutlm_funsd
+  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: ./output/ser_layoutlm_funsd/res/
+  infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
+  # if you want to predict using the groundtruth ocr info,
+  # you can use the following config
+  # infer_img: train_data/XFUND/zh_val/val.json
+  # infer_mode: False
+  save_res_path: ./output/ser/xfund_zh/res

Architecture:
  model_type: vqa
-  algorithm: &algorithm "LayoutLM"
+  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
-    name: LayoutLMForSer
+    name: LayoutXLMForSer
    pretrained: True
    checkpoints:
+    # one of base or vi
+    mode: vi
    num_classes: &num_classes 7

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes
+  key: "backbone_out"

Optimizer:
  name: AdamW
@@ -43,7 +51,7 @@ Optimizer:

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/FUNSD/class_list.txt
+  class_path: &class_path train_data/XFUND/class_list_xfun.txt

Metric:
  name: VQASerTokenMetric
@@ -52,9 +60,10 @@ Metric:

Train:
  dataset:
    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
+    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
-      - ./train_data/FUNSD/train.json
+      - train_data/XFUND/zh_train/train.json
+    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
@@ -64,6 +73,8 @@ Train:
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: &use_textline_bbox_info True
+          # one of [None, "tb-yx"]
+          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
@@ -78,8 +89,7 @@ Train:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
@@ -89,9 +99,9 @@ Train:

Eval:
  dataset:
    name: SimpleDataSet
-    data_dir: train_data/FUNSD/testing_data/images/
+    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
-      - ./train_data/FUNSD/test.json
+      - train_data/XFUND/zh_val/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
@@ -101,6 +111,7 @@ Eval:
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
+          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
@@ -115,8 +126,7 @@ Eval:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
...
@@ -3,30 +3,84 @@ Global:
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
-  save_model_dir: ./output/ser_layoutxlm_funsd
+  save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh_udml
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
-  eval_batch_step: [ 0, 57 ]
+  eval_batch_step: [ 0, 19 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
-  infer_img: train_data/FUNSD/testing_data/images/83624198.png
-  save_res_path: output/ser_layoutxlm_funsd/res/
+  infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
+  save_res_path: ./output/ser_layoutxlm_xfund_zh/res

Architecture:
-  model_type: vqa
-  algorithm: &algorithm "LayoutXLM"
-  Transform:
-  Backbone:
-    name: LayoutXLMForSer
-    pretrained: True
-    checkpoints:
-    num_classes: &num_classes 7
+  model_type: &model_type "vqa"
+  name: DistillationModel
+  algorithm: Distillation
+  Models:
+    Teacher:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: &algorithm "LayoutXLM"
+      Transform:
+      Backbone:
+        name: LayoutXLMForSer
+        pretrained: True
+        # one of base or vi
+        mode: vi
+        checkpoints:
+        num_classes: &num_classes 7
+    Student:
+      pretrained:
+      freeze_params: false
+      return_all_feats: true
+      model_type: *model_type
+      algorithm: *algorithm
+      Transform:
+      Backbone:
+        name: LayoutXLMForSer
+        pretrained: True
+        # one of base or vi
+        mode: vi
+        checkpoints:
+        num_classes: *num_classes

Loss:
-  name: VQASerTokenLayoutLMLoss
-  num_classes: *num_classes
+  name: CombinedLoss
+  loss_config_list:
+    - DistillationVQASerTokenLayoutLMLoss:
+        weight: 1.0
+        model_name_list: ["Student", "Teacher"]
+        key: backbone_out
+        num_classes: *num_classes
+    - DistillationSERDMLLoss:
+        weight: 1.0
+        act: "softmax"
+        use_log: true
+        model_name_pairs:
+          - ["Student", "Teacher"]
+        key: backbone_out
+    - DistillationVQADistanceLoss:
+        weight: 0.5
+        mode: "l2"
+        model_name_pairs:
+          - ["Student", "Teacher"]
+        key: hidden_states_5
+        name: "loss_5"
+    - DistillationVQADistanceLoss:
+        weight: 0.5
+        mode: "l2"
+        model_name_pairs:
+          - ["Student", "Teacher"]
+        key: hidden_states_8
+        name: "loss_8"

Optimizer:
  name: AdamW
@@ -36,25 +90,29 @@ Optimizer:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
-    warmup_epoch: 2
+    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
-  name: VQASerTokenLayoutLMPostProcess
-  class_path: &class_path ./train_data/FUNSD/class_list.txt
+  name: DistillationSerPostProcess
+  model_name: ["Student", "Teacher"]
+  key: backbone_out
+  class_path: &class_path train_data/XFUND/class_list_xfun.txt

Metric:
-  name: VQASerTokenMetric
+  name: DistillationMetric
+  base_metric_name: VQASerTokenMetric
  main_indicator: hmean
+  key: "Student"

Train:
  dataset:
    name: SimpleDataSet
-    data_dir: ./train_data/FUNSD/training_data/images/
+    data_dir: train_data/XFUND/zh_train/image
    label_file_list:
-      - ./train_data/FUNSD/train.json
+      - train_data/XFUND/zh_train/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
@@ -64,6 +122,8 @@ Train:
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
+          # one of [None, "tb-yx"]
+          order_method: &order_method "tb-yx"
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
@@ -78,20 +138,19 @@ Train:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: True
    drop_last: False
-    batch_size_per_card: 8
+    batch_size_per_card: 4
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
-    data_dir: train_data/FUNSD/testing_data/images/
+    data_dir: train_data/XFUND/zh_val/image
    label_file_list:
-      - ./train_data/FUNSD/test.json
+      - train_data/XFUND/zh_val/val.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
@@ -100,6 +159,7 @@ Eval:
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
+          order_method: *order_method
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
@@ -114,10 +174,10 @@ Eval:
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
-          # dataloader will return list in this order
-          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
+          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/re_layoutlmv2_funsd
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 57 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: train_data/FUNSD/testing_data/images/83624198.png
  save_res_path: ./output/re_layoutlmv2_funsd/res/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForRe
    pretrained: True
    checkpoints:

Loss:
  name: LossFromOutput
  key: loss
  reduction: mean

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  clip_norm: 10
  lr:
    learning_rate: 0.00005
    warmup_epoch: 10
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQAReTokenLayoutLMPostProcess

Metric:
  name: VQAReTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/FUNSD/training_data/images/
    label_file_list:
      - ./train_data/FUNSD/train.json
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: &class_path train_data/FUNSD/class_list.txt
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/FUNSD/testing_data/images/
    label_file_list:
      - ./train_data/FUNSD/test.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: True
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQAReTokenRelation:
      - VQAReTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 8
    collate_fn: ListCollator
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlm_sroie
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 200 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: train_data/SROIE/test/X00016469670.jpg
  save_res_path: ./output/ser_layoutlm_sroie/res/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLM"
  Transform:
  Backbone:
    name: LayoutLMForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 9

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ./train_data/SROIE/class_list.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/SROIE/train
    label_file_list:
      - ./train_data/SROIE/train.txt
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: &use_textline_bbox_info True
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/SROIE/test
    label_file_list:
      - ./train_data/SROIE/test.txt
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
          use_textline_bbox_info: *use_textline_bbox_info
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlmv2_funsd
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 100 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: train_data/FUNSD/testing_data/images/83624198.png
  save_res_path: ./output/ser_layoutlmv2_funsd/res/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 7

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path train_data/FUNSD/class_list.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/FUNSD/training_data/images/
    label_file_list:
      - ./train_data/FUNSD/train.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/FUNSD/testing_data/images/
    label_file_list:
      - ./train_data/FUNSD/test.json
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutlmv2_sroie
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 200 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: train_data/SROIE/test/X00016469670.jpg
  save_res_path: ./output/ser_layoutlmv2_sroie/res/

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutLMv2"
  Transform:
  Backbone:
    name: LayoutLMv2ForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 9

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ./train_data/SROIE/class_list.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/SROIE/train
    label_file_list:
      - ./train_data/SROIE/train.txt
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/SROIE/test
    label_file_list:
      - ./train_data/SROIE/test.txt
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
Global:
  use_gpu: True
  epoch_num: &epoch_num 200
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutxlm_sroie
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 200 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: train_data/SROIE/test/X00016469670.jpg
  save_res_path: res_img_aug_with_gt

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 9

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ./train_data/SROIE/class_list.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/SROIE/train
    label_file_list:
      - ./train_data/SROIE/train.txt
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/SROIE/test
    label_file_list:
      - ./train_data/SROIE/test.txt
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
Global:
  use_gpu: True
  epoch_num: &epoch_num 100
  log_smooth_window: 10
  print_batch_step: 10
  save_model_dir: ./output/ser_layoutxlm_wildreceipt
  save_epoch_step: 2000
  # evaluation is run every 10 iterations after the 0th iteration
  eval_batch_step: [ 0, 200 ]
  cal_metric_during_train: False
  save_inference_dir:
  use_visualdl: False
  seed: 2022
  infer_img: train_data//wildreceipt/image_files/Image_12/10/845be0dd6f5b04866a2042abd28d558032ef2576.jpeg
  save_res_path: ./output/ser_layoutxlm_wildreceipt/res

Architecture:
  model_type: vqa
  algorithm: &algorithm "LayoutXLM"
  Transform:
  Backbone:
    name: LayoutXLMForSer
    pretrained: True
    checkpoints:
    num_classes: &num_classes 51

Loss:
  name: VQASerTokenLayoutLMLoss
  num_classes: *num_classes

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Linear
    learning_rate: 0.00005
    epochs: *epoch_num
    warmup_epoch: 2
  regularizer:
    name: L2
    factor: 0.00000

PostProcess:
  name: VQASerTokenLayoutLMPostProcess
  class_path: &class_path ./train_data/wildreceipt/class_list.txt

Metric:
  name: VQASerTokenMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/wildreceipt/
    label_file_list:
      - ./train_data/wildreceipt/wildreceipt_train.txt
    ratio_list: [ 1.0 ]
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: &max_seq_len 512
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: train_data/wildreceipt
    label_file_list:
      - ./train_data/wildreceipt/wildreceipt_test.txt
    transforms:
      - DecodeImage: # load image
          img_mode: RGB
          channel_first: False
      - VQATokenLabelEncode: # Class handling label
          contains_re: False
          algorithm: *algorithm
          class_path: *class_path
      - VQATokenPad:
          max_seq_len: *max_seq_len
          return_attention_mask: True
      - VQASerTokenChunk:
          max_seq_len: *max_seq_len
      - Resize:
          size: [224,224]
      - NormalizeImage:
          scale: 1
          mean: [ 123.675, 116.28, 103.53 ]
          std: [ 58.395, 57.12, 57.375 ]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          # dataloader will return list in this order
          keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4
@@ -24,6 +24,7 @@ from shapely.geometry import LineString, Point, Polygon
import json
import copy
from ppocr.utils.logging import get_logger
+from ppocr.data.imaug.vqa.augment import order_by_tbyx


class ClsLabelEncode(object):
@@ -870,6 +871,7 @@ class VQATokenLabelEncode(object):
                 add_special_ids=False,
                 algorithm='LayoutXLM',
                 use_textline_bbox_info=True,
+                 order_method=None,
                 infer_mode=False,
                 ocr_engine=None,
                 **kwargs):
@@ -899,6 +901,8 @@ class VQATokenLabelEncode(object):
        self.infer_mode = infer_mode
        self.ocr_engine = ocr_engine
        self.use_textline_bbox_info = use_textline_bbox_info
+        self.order_method = order_method
+        assert self.order_method in [None, "tb-yx"]

    def split_bbox(self, bbox, text, tokenizer):
        words = text.split()
@@ -938,6 +942,14 @@ class VQATokenLabelEncode(object):
        # load bbox and label info
        ocr_info = self._load_ocr_info(data)

+        for idx in range(len(ocr_info)):
+            if "bbox" not in ocr_info[idx]:
+                ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
+                    "points"])
+
+        if self.order_method == "tb-yx":
+            ocr_info = order_by_tbyx(ocr_info)
+
        # for re
        train_re = self.contains_re and not self.infer_mode
        if train_re:
@@ -1052,10 +1064,10 @@ class VQATokenLabelEncode(object):
        return data

    def trans_poly_to_bbox(self, poly):
-        x1 = np.min([p[0] for p in poly])
-        x2 = np.max([p[0] for p in poly])
-        y1 = np.min([p[1] for p in poly])
-        y2 = np.max([p[1] for p in poly])
+        x1 = int(np.min([p[0] for p in poly]))
+        x2 = int(np.max([p[0] for p in poly]))
+        y1 = int(np.min([p[1] for p in poly]))
+        y2 = int(np.max([p[1] for p in poly]))
        return [x1, y1, x2, y2]

    def _load_ocr_info(self, data):
...
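The `trans_poly_to_bbox` change above casts the min/max coordinates to int, so downstream sorting and the tb-yx threshold compare plain integers. A standalone equivalent for illustration:

import numpy as np

def trans_poly_to_bbox(poly):
    xs = [p[0] for p in poly]
    ys = [p[1] for p in poly]
    return [int(np.min(xs)), int(np.min(ys)), int(np.max(xs)), int(np.max(ys))]

print(trans_poly_to_bbox([[10.2, 5.9], [40.8, 6.1], [40.5, 20.0], [10.0, 19.7]]))
# [10, 5, 40, 20]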
@@ -13,12 +13,10 @@
# limitations under the License.

from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
-from .augment import DistortBBox

__all__ = [
    'VQATokenPad',
    'VQASerTokenChunk',
    'VQAReTokenChunk',
    'VQAReTokenRelation',
-    'DistortBBox',
]
@@ -16,22 +16,18 @@ import os
import sys
import numpy as np
import random
+from copy import deepcopy


-class DistortBBox:
-    def __init__(self, prob=0.5, max_scale=1, **kwargs):
-        """Random distort bbox
-        """
-        self.prob = prob
-        self.max_scale = max_scale
-
-    def __call__(self, data):
-        if random.random() > self.prob:
-            return data
-        bbox = np.array(data['bbox'])
-        rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
-        bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
-        data['bbox'] = np.clip(data['bbox'], 0, 1000)
-        data['bbox'] = bbox.tolist()
-        sys.stdout.flush()
-        return data
+def order_by_tbyx(ocr_info):
+    res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
+    for i in range(len(res) - 1):
+        for j in range(i, 0, -1):
+            if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
+                    (res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
+                tmp = deepcopy(res[j])
+                res[j] = deepcopy(res[j + 1])
+                res[j + 1] = deepcopy(tmp)
+            else:
+                break
+    return res
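`order_by_tbyx` first sorts boxes by (y, x) and then bubbles a left-hand neighbor forward whenever two adjacent boxes differ in y by less than 20 pixels but are in reversed x order, so boxes on the same visual line end up reading left to right. A usage example (with the function above in scope):

boxes = [
    {"text": "title", "bbox": [40, 0, 200, 18]},
    {"text": "right", "bbox": [120, 30, 200, 48]},
    {"text": "left", "bbox": [10, 34, 90, 52]},
]
print([r["text"] for r in order_by_tbyx(boxes)])  # ['title', 'left', 'right']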
@@ -63,18 +63,21 @@ class KLJSLoss(object):

    def __call__(self, p1, p2, reduction="mean"):
        if self.mode.lower() == 'kl':
-            loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
+            loss = paddle.multiply(p2,
+                                   paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
            loss += paddle.multiply(
                p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
            loss *= 0.5
        elif self.mode.lower() == "js":
-            loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+            loss = paddle.multiply(
+                p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
            loss += paddle.multiply(
-                p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
+                p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
            loss *= 0.5
        else:
-            raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
+            raise ValueError(
+                "The mode.lower() if KLJSLoss should be one of ['kl', 'js']")

        if reduction == "mean":
            loss = paddle.mean(loss, axis=[1, 2])
        elif reduction == "none" or reduction is None:
@@ -154,7 +157,9 @@ class LossFromOutput(nn.Layer):
        self.reduction = reduction

    def forward(self, predicts, batch):
-        loss = predicts[self.key]
+        loss = predicts
+        if self.key is not None and isinstance(predicts, dict):
+            loss = loss[self.key]
        if self.reduction == 'mean':
            loss = paddle.mean(loss)
        elif self.reduction == 'sum':
...
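The reworked `LossFromOutput.forward` now tolerates both dict outputs and raw tensors; indexing by `self.key` only happens when a dict arrives. Reduced to its core behavior:

def extract(predicts, key="loss"):
    # pass raw values through; index dicts by key
    loss = predicts
    if key is not None and isinstance(predicts, dict):
        loss = loss[key]
    return loss

assert extract({"loss": 0.3}) == 0.3 and extract(0.3) == 0.3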
@@ -24,6 +24,9 @@ from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
+from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
+from .distillation_loss import DistillationLossFromOutput
+from .distillation_loss import DistillationVQADistanceLoss


class CombinedLoss(nn.Layer):
...
@@ -21,8 +21,10 @@ from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
+from .basic_loss import LossFromOutput
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
+from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss


def _sum_loss(loss_dict):
@@ -322,3 +324,133 @@ class DistillationDistanceLoss(DistanceLoss):
                loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
                                               idx)] = loss
        return loss_dict


class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
    def __init__(self,
                 num_classes,
                 model_name_list=[],
                 key=None,
                 name="loss_ser"):
        super().__init__(num_classes=num_classes)
        self.model_name_list = model_name_list
        self.key = key
        self.name = name

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            out = predicts[model_name]
            if self.key is not None:
                out = out[self.key]
            loss = super().forward(out, batch)
            loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
        return loss_dict


class DistillationLossFromOutput(LossFromOutput):
    def __init__(self,
                 reduction="none",
                 model_name_list=[],
                 dist_key=None,
                 key="loss",
                 name="loss_re"):
        super().__init__(key=key, reduction=reduction)
        self.model_name_list = model_name_list
        self.name = name
        self.dist_key = dist_key

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            out = predicts[model_name]
            if self.dist_key is not None:
                out = out[self.dist_key]
            loss = super().forward(out, batch)
            loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
        return loss_dict


class DistillationSERDMLLoss(DMLLoss):
    """
    """

    def __init__(self,
                 act="softmax",
                 use_log=True,
                 num_classes=7,
                 model_name_pairs=[],
                 key=None,
                 name="loss_dml_ser"):
        super().__init__(act=act, use_log=use_log)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.name = name
        self.num_classes = num_classes
        self.model_name_pairs = model_name_pairs

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
            out1 = out1.reshape([-1, out1.shape[-1]])
            out2 = out2.reshape([-1, out2.shape[-1]])

            attention_mask = batch[2]
            if attention_mask is not None:
                active_output = attention_mask.reshape([-1, ]) == 1
                out1 = out1[active_output]
                out2 = out2[active_output]

            loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
                                                                        out2)
        return loss_dict


class DistillationVQADistanceLoss(DistanceLoss):
    def __init__(self,
                 mode="l2",
                 model_name_pairs=[],
                 key=None,
                 name="loss_distance",
                 **kargs):
        super().__init__(mode=mode, **kargs)
        assert isinstance(model_name_pairs, list)
        self.key = key
        self.model_name_pairs = model_name_pairs
        self.name = name + "_l2"

    def forward(self, predicts, batch):
        loss_dict = dict()
        for idx, pair in enumerate(self.model_name_pairs):
            out1 = predicts[pair[0]]
            out2 = predicts[pair[1]]
            attention_mask = batch[2]
            if self.key is not None:
                out1 = out1[self.key]
                out2 = out2[self.key]
                if attention_mask is not None:
                    max_len = attention_mask.shape[-1]
                    out1 = out1[:, :max_len]
                    out2 = out2[:, :max_len]
                out1 = out1.reshape([-1, out1.shape[-1]])
                out2 = out2.reshape([-1, out2.shape[-1]])
            if attention_mask is not None:
                active_output = attention_mask.reshape([-1, ]) == 1
                out1 = out1[active_output]
                out2 = out2[active_output]

            loss = super().forward(out1, out2)
            if isinstance(loss, dict):
                for key in loss:
                    loss_dict["{}_{}nohu_{}".format(self.name, key,
                                                    idx)] = loss[key]
            else:
                loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
                                               idx)] = loss
        return loss_dict
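All of the SER-side distillation losses above share one trick: flatten [batch, seq_len, num_classes] logits to [batch*seq_len, num_classes] and keep only positions where attention_mask == 1, so padding tokens never contribute. A NumPy sketch of that selection:

import numpy as np

logits = np.zeros((2, 4, 7))               # [batch, seq_len, num_classes]
attention_mask = np.array([[1, 1, 0, 0],
                           [1, 1, 1, 0]])
flat = logits.reshape(-1, logits.shape[-1])
active = attention_mask.reshape(-1) == 1
print(flat[active].shape)                  # (5, 7)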
@@ -17,26 +17,30 @@ from __future__ import division
from __future__ import print_function

from paddle import nn
+from ppocr.losses.basic_loss import DMLLoss


class VQASerTokenLayoutLMLoss(nn.Layer):
-    def __init__(self, num_classes):
+    def __init__(self, num_classes, key=None):
        super().__init__()
        self.loss_class = nn.CrossEntropyLoss()
        self.num_classes = num_classes
        self.ignore_index = self.loss_class.ignore_index
+        self.key = key

    def forward(self, predicts, batch):
+        if isinstance(predicts, dict) and self.key is not None:
+            predicts = predicts[self.key]
        labels = batch[5]
        attention_mask = batch[2]
        if attention_mask is not None:
            active_loss = attention_mask.reshape([-1, ]) == 1
-            active_outputs = predicts.reshape(
+            active_output = predicts.reshape(
                [-1, self.num_classes])[active_loss]
-            active_labels = labels.reshape([-1, ])[active_loss]
-            loss = self.loss_class(active_outputs, active_labels)
+            active_label = labels.reshape([-1, ])[active_loss]
+            loss = self.loss_class(active_output, active_label)
        else:
            loss = self.loss_class(
                predicts.reshape([-1, self.num_classes]),
                labels.reshape([-1, ]))
        return {'loss': loss}
\ No newline at end of file
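The new `key` argument lets the same loss consume either raw logits or the dict emitted by the distillation wrappers; with `key: "backbone_out"` (as set in the configs above), the logits are pulled out first. In miniature:

predicts = {"backbone_out": "logits", "hidden_states_5": "feats"}
key = "backbone_out"
if isinstance(predicts, dict) and key is not None:
    predicts = predicts[key]
assert predicts == "logits"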
@@ -19,6 +19,8 @@ from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric
+from .vqa_token_ser_metric import VQASerTokenMetric
+from .vqa_token_re_metric import VQAReTokenMetric


class DistillationMetric(object):
...
@@ -73,28 +73,40 @@ class BaseModel(nn.Layer):
        self.return_all_feats = config.get("return_all_feats", False)

    def forward(self, x, data=None):

        y = dict()
        if self.use_transform:
            x = self.transform(x)
        x = self.backbone(x)
-        y["backbone_out"] = x
+        if isinstance(x, dict):
+            y.update(x)
+        else:
+            y["backbone_out"] = x
+        final_name = "backbone_out"
        if self.use_neck:
            x = self.neck(x)
-            y["neck_out"] = x
+            if isinstance(x, dict):
+                y.update(x)
+            else:
+                y["neck_out"] = x
+            final_name = "neck_out"
        if self.use_head:
            x = self.head(x, targets=data)
            # for multi head, save ctc neck out for udml
            if isinstance(x, dict) and 'ctc_neck' in x.keys():
                y["neck_out"] = x["ctc_neck"]
                y["head_out"] = x
            elif isinstance(x, dict):
                y.update(x)
            else:
                y["head_out"] = x
+            final_name = "head_out"
        if self.return_all_feats:
            if self.training:
                return y
+            elif isinstance(x, dict):
+                return x
            else:
-                return {"head_out": y["head_out"]}
+                return {final_name: x}
        else:
            return x
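With `final_name` tracked per stage, the eval-time `return_all_feats` path now returns whichever stage actually ran last instead of a hard-coded "head_out", which matters for the VQA backbones that have no neck or head. Reduced to its core:

def final_output(x, final_name):
    # dict outputs pass through; raw outputs are wrapped under the last stage name
    return x if isinstance(x, dict) else {final_name: x}

assert final_output("feats", "backbone_out") == {"backbone_out": "feats"}
assert final_output({"loss": 0.1}, "head_out") == {"loss": 0.1}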
@@ -22,13 +22,22 @@ from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
+from paddlenlp.transformers import AutoModel

-__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
+__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]

pretrained_model_dict = {
-    LayoutXLMModel: 'layoutxlm-base-uncased',
-    LayoutLMModel: 'layoutlm-base-uncased',
-    LayoutLMv2Model: 'layoutlmv2-base-uncased'
+    LayoutXLMModel: {
+        "base": "layoutxlm-base-uncased",
+        "vi": "layoutxlm-wo-backbone-base-uncased",
+    },
+    LayoutLMModel: {
+        "base": "layoutlm-base-uncased",
+    },
+    LayoutLMv2Model: {
+        "base": "layoutlmv2-base-uncased",
+        "vi": "layoutlmv2-wo-backbone-base-uncased",
+    },
}
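The weight lookup is now two-level: first the base model class, then the `mode` string from the config ("base", or "vi" for the variant trained without the visual backbone). A sketch with string keys standing in for the actual model classes:

pretrained_model_dict = {"LayoutXLMModel": {
    "base": "layoutxlm-base-uncased",
    "vi": "layoutxlm-wo-backbone-base-uncased",
}}
mode = "vi"
assert pretrained_model_dict["LayoutXLMModel"][mode].startswith("layoutxlm-wo")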
...@@ -36,42 +45,47 @@ class NLPBaseModel(nn.Layer): ...@@ -36,42 +45,47 @@ class NLPBaseModel(nn.Layer):
def __init__(self, def __init__(self,
base_model_class, base_model_class,
model_class, model_class,
type='ser', mode="base",
type="ser",
pretrained=True, pretrained=True,
checkpoints=None, checkpoints=None,
**kwargs): **kwargs):
super(NLPBaseModel, self).__init__() super(NLPBaseModel, self).__init__()
if checkpoints is not None: if checkpoints is not None: # load the trained model
self.model = model_class.from_pretrained(checkpoints) self.model = model_class.from_pretrained(checkpoints)
elif isinstance(pretrained, (str, )) and os.path.exists(pretrained): else: # load the pretrained-model
self.model = model_class.from_pretrained(pretrained) pretrained_model_name = pretrained_model_dict[base_model_class][
else: mode]
pretrained_model_name = pretrained_model_dict[base_model_class]
if pretrained is True: if pretrained is True:
base_model = base_model_class.from_pretrained( base_model = base_model_class.from_pretrained(
pretrained_model_name) pretrained_model_name)
else: else:
base_model = base_model_class( base_model = base_model_class.from_pretrained(pretrained)
**base_model_class.pretrained_init_configuration[ if type == "ser":
pretrained_model_name])
if type == 'ser':
self.model = model_class( self.model = model_class(
base_model, num_classes=kwargs['num_classes'], dropout=None) base_model, num_classes=kwargs["num_classes"], dropout=None)
else: else:
self.model = model_class(base_model, dropout=None) self.model = model_class(base_model, dropout=None)
self.out_channels = 1 self.out_channels = 1
self.use_visual_backbone = True
class LayoutLMForSer(NLPBaseModel): class LayoutLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None, def __init__(self,
num_classes,
pretrained=True,
checkpoints=None,
mode="base",
**kwargs): **kwargs):
super(LayoutLMForSer, self).__init__( super(LayoutLMForSer, self).__init__(
LayoutLMModel, LayoutLMModel,
LayoutLMForTokenClassification, LayoutLMForTokenClassification,
'ser', mode,
"ser",
pretrained, pretrained,
checkpoints, checkpoints,
num_classes=num_classes) num_classes=num_classes, )
self.use_visual_backbone = False
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
...@@ -85,62 +99,92 @@ class LayoutLMForSer(NLPBaseModel): ...@@ -85,62 +99,92 @@ class LayoutLMForSer(NLPBaseModel):
class LayoutLMv2ForSer(NLPBaseModel): class LayoutLMv2ForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None, def __init__(self,
num_classes,
pretrained=True,
checkpoints=None,
mode="base",
**kwargs): **kwargs):
super(LayoutLMv2ForSer, self).__init__( super(LayoutLMv2ForSer, self).__init__(
LayoutLMv2Model, LayoutLMv2Model,
LayoutLMv2ForTokenClassification, LayoutLMv2ForTokenClassification,
'ser', mode,
"ser",
pretrained, pretrained,
checkpoints, checkpoints,
num_classes=num_classes) num_classes=num_classes)
self.use_visual_backbone = True
if hasattr(self.model.layoutlmv2, "use_visual_backbone"
) and self.model.layoutlmv2.use_visual_backbone is False:
self.use_visual_backbone = False
def forward(self, x): def forward(self, x):
if self.use_visual_backbone is True:
image = x[4]
else:
image = None
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
attention_mask=x[2], attention_mask=x[2],
token_type_ids=x[3], token_type_ids=x[3],
image=x[4], image=image,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None) labels=None)
if not self.training: if self.training:
res = {"backbone_out": x[0]}
res.update(x[1])
return res
else:
return x return x
return x[0]
class LayoutXLMForSer(NLPBaseModel): class LayoutXLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None, def __init__(self,
num_classes,
pretrained=True,
checkpoints=None,
mode="base",
**kwargs): **kwargs):
super(LayoutXLMForSer, self).__init__( super(LayoutXLMForSer, self).__init__(
LayoutXLMModel, LayoutXLMModel,
LayoutXLMForTokenClassification, LayoutXLMForTokenClassification,
'ser', mode,
"ser",
pretrained, pretrained,
checkpoints, checkpoints,
num_classes=num_classes) num_classes=num_classes)
self.use_visual_backbone = True
def forward(self, x): def forward(self, x):
if self.use_visual_backbone is True:
image = x[4]
else:
image = None
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
attention_mask=x[2], attention_mask=x[2],
token_type_ids=x[3], token_type_ids=x[3],
image=x[4], image=image,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None) labels=None)
if not self.training: if self.training:
res = {"backbone_out": x[0]}
res.update(x[1])
return res
else:
return x return x
return x[0]
class LayoutLMv2ForRe(NLPBaseModel): class LayoutLMv2ForRe(NLPBaseModel):
def __init__(self, pretrained=True, checkpoints=None, **kwargs): def __init__(self, pretrained=True, checkpoints=None, mode="base",
super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model, **kwargs):
LayoutLMv2ForRelationExtraction, super(LayoutLMv2ForRe, self).__init__(
're', pretrained, checkpoints) LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
pretrained, checkpoints)
def forward(self, x): def forward(self, x):
x = self.model( x = self.model(
...@@ -158,18 +202,27 @@ class LayoutLMv2ForRe(NLPBaseModel): ...@@ -158,18 +202,27 @@ class LayoutLMv2ForRe(NLPBaseModel):
class LayoutXLMForRe(NLPBaseModel): class LayoutXLMForRe(NLPBaseModel):
def __init__(self, pretrained=True, checkpoints=None, **kwargs): def __init__(self, pretrained=True, checkpoints=None, mode="base",
super(LayoutXLMForRe, self).__init__(LayoutXLMModel, **kwargs):
LayoutXLMForRelationExtraction, super(LayoutXLMForRe, self).__init__(
're', pretrained, checkpoints) LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
pretrained, checkpoints)
self.use_visual_backbone = True
if hasattr(self.model.layoutxlm, "use_visual_backbone"
) and self.model.layoutxlm.use_visual_backbone is False:
self.use_visual_backbone = False
def forward(self, x): def forward(self, x):
if self.use_visual_backbone is True:
image = x[4]
else:
image = None
x = self.model( x = self.model(
input_ids=x[0], input_ids=x[0],
bbox=x[1], bbox=x[1],
attention_mask=x[2], attention_mask=x[2],
token_type_ids=x[3], token_type_ids=x[3],
image=x[4], image=image,
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
labels=None, labels=None,
......
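For reference: the new `mode` argument is just a second-level key into `pretrained_model_dict`, and `use_visual_backbone` decides whether the image tensor (slot 4 of the input list) reaches the model. A minimal standalone sketch of both mechanisms, with string keys standing in for the real paddlenlp classes:

# Standalone sketch of the mode lookup and image gating above; the string
# keys stand in for the real paddlenlp model classes.
pretrained_model_dict = {
    "LayoutXLMModel": {
        "base": "layoutxlm-base-uncased",
        "vi": "layoutxlm-wo-backbone-base-uncased",  # "vi" = no visual backbone
    },
}


def resolve_pretrained_name(base_model_class, mode="base"):
    # Mirrors NLPBaseModel.__init__: a missing (model, mode) pair raises KeyError.
    return pretrained_model_dict[base_model_class][mode]


def select_image(batch, use_visual_backbone):
    # Batch layout assumed from the forward() methods above:
    # [input_ids, bbox, attention_mask, token_type_ids, image, ...]
    return batch[4] if use_visual_backbone else None


assert (resolve_pretrained_name("LayoutXLMModel", "vi")
        == "layoutxlm-wo-backbone-base-uncased")
assert select_image([0, 1, 2, 3, "img"], use_visual_backbone=False) is None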
@@ -31,21 +31,38 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
     SPINLabelDecode
 from .cls_postprocess import ClsPostProcess
 from .pg_postprocess import PGPostProcess
-from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
-from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
+from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
+from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
 from .table_postprocess import TableMasterLabelDecode, TableLabelDecode


 def build_post_process(config, global_config=None):
     support_dict = [
-        'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
-        'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
-        'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
-        'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
-        'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
-        'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
-        'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
-        'TableMasterLabelDecode', 'SPINLabelDecode'
+        'DBPostProcess',
+        'EASTPostProcess',
+        'SASTPostProcess',
+        'FCEPostProcess',
+        'CTCLabelDecode',
+        'AttnLabelDecode',
+        'ClsPostProcess',
+        'SRNLabelDecode',
+        'PGPostProcess',
+        'DistillationCTCLabelDecode',
+        'TableLabelDecode',
+        'DistillationDBPostProcess',
+        'NRTRLabelDecode',
+        'SARLabelDecode',
+        'SEEDLabelDecode',
+        'VQASerTokenLayoutLMPostProcess',
+        'VQAReTokenLayoutLMPostProcess',
+        'PRENLabelDecode',
+        'DistillationSARLabelDecode',
+        'ViTSTRLabelDecode',
+        'ABINetLabelDecode',
+        'TableMasterLabelDecode',
+        'SPINLabelDecode',
+        'DistillationSerPostProcess',
+        'DistillationRePostProcess',
     ]

     if config['name'] == 'PSEPostProcess':
...
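The entries above only register the new names; dispatch stays name-based. A hedged sketch of a config dict that would now resolve to the SER distillation post-process (the class_path value is a placeholder):

# Hypothetical PostProcess config: build_post_process instantiates the class
# named by "name" and passes the remaining keys as keyword arguments.
post_process_config = {
    "name": "DistillationSerPostProcess",
    "model_name": ["Student"],
    "key": "backbone_out",
    "class_path": "path/to/class_list.txt",  # placeholder
}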
@@ -49,3 +49,25 @@ class VQAReTokenLayoutLMPostProcess(object):
                 result.append((ocr_info_head, ocr_info_tail))
             results.append(result)
         return results
+
+
+class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
+    """
+    DistillationRePostProcess
+    """
+
+    def __init__(self, model_name=["Student"], key=None, **kwargs):
+        super().__init__(**kwargs)
+        if not isinstance(model_name, list):
+            model_name = [model_name]
+        self.model_name = model_name
+        self.key = key
+
+    def __call__(self, preds, *args, **kwargs):
+        output = dict()
+        for name in self.model_name:
+            pred = preds[name]
+            if self.key is not None:
+                pred = pred[self.key]
+            output[name] = super().__call__(pred, *args, **kwargs)
+        return output
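Taken on its own, the wrapper's contract is: predictions arrive as a dict keyed by sub-model name, optionally nested one level under `key`. A runnable toy version with the base decoder stubbed out (the SER variant in the next file differs only in forwarding `class_path` and `batch`):

# Toy fan-out mirroring DistillationRePostProcess.__call__; the base decoder
# is stubbed so the example runs standalone.
class StubBaseDecode:
    def __call__(self, pred, *args, **kwargs):
        return "decoded({})".format(pred)


class StubDistillationDecode(StubBaseDecode):
    def __init__(self, model_name=["Student"], key=None):
        self.model_name = model_name if isinstance(model_name, list) else [model_name]
        self.key = key

    def __call__(self, preds, *args, **kwargs):
        output = dict()
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]  # unwrap e.g. "backbone_out"
            output[name] = super().__call__(pred, *args, **kwargs)
        return output


post = StubDistillationDecode(model_name=["Student", "Teacher"], key="backbone_out")
print(post({"Student": {"backbone_out": 1}, "Teacher": {"backbone_out": 2}}))
# -> {'Student': 'decoded(1)', 'Teacher': 'decoded(2)'}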
@@ -93,3 +93,25 @@ class VQASerTokenLayoutLMPostProcess(object):
             ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
             results.append(ocr_info)
         return results
+
+
+class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
+    """
+    DistillationSerPostProcess
+    """
+
+    def __init__(self, class_path, model_name=["Student"], key=None, **kwargs):
+        super().__init__(class_path, **kwargs)
+        if not isinstance(model_name, list):
+            model_name = [model_name]
+        self.model_name = model_name
+        self.key = key
+
+    def __call__(self, preds, batch=None, *args, **kwargs):
+        output = dict()
+        for name in self.model_name:
+            pred = preds[name]
+            if self.key is not None:
+                pred = pred[self.key]
+            output[name] = super().__call__(pred, batch=batch, *args, **kwargs)
+        return output
@@ -55,6 +55,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
     best_model_dict = {}

     if model_type == 'vqa':
+        # NOTE: for vqa model, resume training is not supported now
+        if config["Architecture"]["algorithm"] in ["Distillation"]:
+            return best_model_dict
         checkpoints = config['Architecture']['Backbone']['checkpoints']

         # load vqa method metric
         if checkpoints:
@@ -78,6 +81,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
                 logger.warning(
                     "{}.pdopt is not exists, params of optimizer is not loaded".
                     format(checkpoints))
+
         return best_model_dict

     if checkpoints:
@@ -166,15 +170,19 @@ def save_model(model,
     """
     _mkdir_if_not_exist(model_path, logger)
     model_prefix = os.path.join(model_path, prefix)
-    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
+    if config['Architecture']["model_type"] != 'vqa':
+        paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
     if config['Architecture']["model_type"] != 'vqa':
         paddle.save(model.state_dict(), model_prefix + '.pdparams')
         metric_prefix = model_prefix
-    else:
+    else:  # for vqa system, we follow the save/load rules in NLP
         if config['Global']['distributed']:
-            model._layers.backbone.model.save_pretrained(model_prefix)
+            arch = model._layers
         else:
-            model.backbone.model.save_pretrained(model_prefix)
+            arch = model
+        if config["Architecture"]["algorithm"] in ["Distillation"]:
+            arch = arch.Student
+        arch.backbone.model.save_pretrained(model_prefix)
         metric_prefix = os.path.join(model_prefix, 'metric')
     # save metric and config
     with open(metric_prefix + '.states', 'wb') as f:
...
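A condensed, runnable sketch of the unwrap order now used in save_model, with stub objects in place of real Paddle layers (the `Student` attribute name comes from the diff; the "SER" algorithm string is illustrative):

# Unwrap order in save_model: peel the DataParallel wrapper first, then pick
# the Student branch of a distillation architecture before save_pretrained.
class _Stub:
    pass


def unwrap_arch(model, distributed, algorithm):
    # paddle.DataParallel keeps the wrapped network in ._layers
    arch = model._layers if distributed else model
    if algorithm in ["Distillation"]:
        # distillation archs expose sub-models as attributes; Student is saved
        arch = arch.Student
    return arch


plain = _Stub()
assert unwrap_arch(plain, distributed=False, algorithm="SER") is plain

distilled = _Stub()
distilled.Student = _Stub()
wrapper = _Stub()
wrapper._layers = distilled
assert unwrap_arch(wrapper, distributed=True, algorithm="Distillation") is distilled.Student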
@@ -38,6 +38,7 @@ def init_args():
     parser.add_argument("--ir_optim", type=str2bool, default=True)
     parser.add_argument("--use_tensorrt", type=str2bool, default=False)
     parser.add_argument("--min_subgraph_size", type=int, default=15)
+    parser.add_argument("--shape_info_filename", type=str, default=None)
     parser.add_argument("--precision", type=str, default="fp32")
     parser.add_argument("--gpu_mem", type=int, default=500)
@@ -204,9 +205,18 @@ def create_predictor(args, mode, logger):
             workspace_size=1 << 30,
             precision_mode=precision,
             max_batch_size=args.max_batch_size,
-            min_subgraph_size=args.min_subgraph_size,
+            min_subgraph_size=args.min_subgraph_size,  # skip the minimum trt subgraph
             use_calib_mode=False)
-        # skip the minmum trt subgraph
+
+        # collect shape
+        if args.shape_info_filename is not None:
+            if not os.path.exists(args.shape_info_filename):
+                config.collect_shape_range_info(args.shape_info_filename)
+                logger.info(
+                    f"collect dynamic shape info into: {args.shape_info_filename}")
+            else:
+                logger.info(
+                    f"dynamic shape info file ({args.shape_info_filename}) already exists, no need to generate it again.")
+            config.enable_tuned_tensorrt_dynamic_shape(args.shape_info_filename,
+                                                       True)
         use_dynamic_shape = True
         if mode == "det":
             min_input_shape = {
...
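The two Paddle Inference calls added above implement offline TRT shape tuning: a first run records tensor shape ranges into a file, and later runs feed that file to TensorRT. A simplified sketch with placeholder file names (the utility above additionally enables tuned shapes in the same run that collects them):

import os

from paddle.inference import Config

SHAPE_FILE = "vqa_trt_shape.txt"                     # placeholder output path
config = Config("inference.pdmodel",                 # placeholder model files
                "inference.pdiparams")

if not os.path.exists(SHAPE_FILE):
    # First run: record the min/max/opt shape of every tensor seen at runtime.
    config.collect_shape_range_info(SHAPE_FILE)
else:
    # Later runs: hand the recorded ranges to TensorRT; the second argument
    # permits building engines at runtime for shapes outside the tuned range.
    config.enable_tuned_tensorrt_dynamic_shape(SHAPE_FILE, True)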
@@ -113,10 +113,13 @@ def make_input(ser_inputs, ser_results):

 class SerRePredictor(object):
     def __init__(self, config, ser_config):
+        global_config = config['Global']
+        if "infer_mode" in global_config:
+            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
+
         self.ser_engine = SerPredictor(ser_config)

         # init re model
-        global_config = config['Global']

         # build post process
         self.post_process_class = build_post_process(config['PostProcess'],
@@ -130,8 +133,8 @@ class SerRePredictor(object):
         self.model.eval()

-    def __call__(self, img_path):
-        ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
+    def __call__(self, data):
+        ser_results, ser_inputs = self.ser_engine(data)
         re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
         preds = self.model(re_input)
         post_result = self.post_process_class(
@@ -173,18 +176,33 @@ if __name__ == '__main__':
     ser_re_engine = SerRePredictor(config, ser_config)

-    infer_imgs = get_image_file_list(config['Global']['infer_img'])
+    if config["Global"].get("infer_mode", None) is False:
+        data_dir = config['Eval']['dataset']['data_dir']
+        with open(config['Global']['infer_img'], "rb") as f:
+            infer_imgs = f.readlines()
+    else:
+        infer_imgs = get_image_file_list(config['Global']['infer_img'])

     with open(
             os.path.join(config['Global']['save_res_path'],
                          "infer_results.txt"),
             "w",
             encoding='utf-8') as fout:
-        for idx, img_path in enumerate(infer_imgs):
+        for idx, info in enumerate(infer_imgs):
+            if config["Global"].get("infer_mode", None) is False:
+                data_line = info.decode('utf-8')
+                substr = data_line.strip("\n").split("\t")
+                img_path = os.path.join(data_dir, substr[0])
+                data = {'img_path': img_path, 'label': substr[1]}
+            else:
+                img_path = info
+                data = {'img_path': img_path}
             save_img_path = os.path.join(
                 config['Global']['save_res_path'],
                 os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
-            result = ser_re_engine(img_path)
+            result = ser_re_engine(data)
             result = result[0]
             fout.write(img_path + "\t" + json.dumps(
                 {
...
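With infer_mode set to False, the script reads an evaluation-style annotation file instead of globbing images; each line is a relative image path and a label, separated by a tab. A hedged illustration of the parsing, with an invented sample line and paths:

import os

data_dir = "train_data/XFUND/zh_val/image"           # invented path
line = b"zh_val_21.jpg\t[{'transcription': '...'}]"  # invented sample line

substr = line.decode("utf-8").strip("\n").split("\t")
data = {
    "img_path": os.path.join(data_dir, substr[0]),
    "label": substr[1],  # ground-truth OCR info passed through to the SER engine
}
print(data["img_path"])  # train_data/XFUND/zh_val/image/zh_val_21.jpg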