未验证 提交 9e4ae9dc 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

add vqa code (#7096)

* add vqa code

* add order ocr info

* rename tb-yx order

* polish configs

* add trt offline-tuning

* fix seed and remove unused configs
上级 8c4e1dcc
......@@ -6,11 +6,11 @@ Global:
save_model_dir: ./output/re_layoutlmv2_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2048
seed: 2022
infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re_layoutlmv2_xfund_zh/res/
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
epoch_num: &epoch_num 130
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutxlm/
save_model_dir: ./output/re_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
......@@ -12,7 +12,7 @@ Global:
use_visualdl: False
seed: 2022
infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/
save_res_path: ./output/re_layoutxlm_xfund_zh/res/
Architecture:
model_type: vqa
......@@ -81,7 +81,7 @@ Train:
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
batch_size_per_card: 2
num_workers: 8
collate_fn: ListCollator
......
......@@ -6,13 +6,13 @@ Global:
save_model_dir: ./output/ser_layoutlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser_layoutlm_xfund_zh/res/
save_res_path: ./output/re_layoutlm_xfund_zh/res
Architecture:
model_type: vqa
......@@ -55,6 +55,7 @@ Train:
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
......
......@@ -27,6 +27,7 @@ Architecture:
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
key: "backbone_out"
Optimizer:
name: AdamW
......
......@@ -27,6 +27,7 @@ Architecture:
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
key: "backbone_out"
Optimizer:
name: AdamW
......
Global:
use_gpu: True
epoch_num: &epoch_num 200
epoch_num: &epoch_num 130
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutxlm_funsd
save_model_dir: ./output/re_vi_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutxlm_funsd/res/
infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/xfund_zh/with_gt
Architecture:
model_type: vqa
......@@ -21,6 +21,7 @@ Architecture:
Backbone:
name: LayoutXLMForRe
pretrained: True
mode: vi
checkpoints:
Loss:
......@@ -50,10 +51,9 @@ Metric:
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- ./train_data/FUNSD/train_v4.json
# - ./train_data/FUNSD/train.json
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
......@@ -62,8 +62,9 @@ Train:
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path ./train_data/FUNSD/class_list.txt
class_path: &class_path train_data/XFUND/class_list_xfun.txt
use_textline_bbox_info: &use_textline_bbox_info True
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
......@@ -79,22 +80,20 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 16
batch_size_per_card: 2
num_workers: 4
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test_v4.json
# - ./train_data/FUNSD/test.json
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -104,6 +103,7 @@ Eval:
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
......@@ -119,11 +119,11 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Global:
use_gpu: True
epoch_num: &epoch_num 130
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_vi_layoutxlm_xfund_zh_udml
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: ppstructure/docs/vqa/input/zh_val_21.jpg
save_res_path: ./output/re/xfund_zh/with_gt
Architecture:
model_type: &model_type "vqa"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: True
mode: vi
checkpoints:
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: *algorithm
Transform:
Backbone:
name: LayoutXLMForRe
pretrained: True
mode: vi
checkpoints:
Loss:
name: CombinedLoss
loss_config_list:
- DistillationLossFromOutput:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: loss
reduction: mean
- DistillationVQADistanceLoss:
weight: 0.5
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_5
name: "loss_5"
- DistillationVQADistanceLoss:
weight: 0.5
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_8
name: "loss_8"
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: DistillationRePostProcess
model_name: ["Student", "Teacher"]
key: null
Metric:
name: DistillationMetric
base_metric_name: VQAReTokenMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path train_data/XFUND/class_list_xfun.txt
use_textline_bbox_info: &use_textline_bbox_info True
# [None, "tb-yx"]
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox','attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 2
num_workers: 4
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
......@@ -3,30 +3,38 @@ Global:
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_funsd
save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlm_funsd/res/
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
# if you want to predict using the groundtruth ocr info,
# you can use the following config
# infer_img: train_data/XFUND/zh_val/val.json
# infer_mode: False
save_res_path: ./output/ser/xfund_zh/res
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutLMForSer
name: LayoutXLMForSer
pretrained: True
checkpoints:
# one of base or vi
mode: vi
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
key: "backbone_out"
Optimizer:
name: AdamW
......@@ -43,7 +51,7 @@ Optimizer:
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
......@@ -52,9 +60,10 @@ Metric:
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- ./train_data/FUNSD/train.json
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -64,6 +73,8 @@ Train:
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
# one of [None, "tb-yx"]
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
......@@ -78,8 +89,7 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
......@@ -89,9 +99,9 @@ Train:
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- ./train_data/FUNSD/test.json
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -101,6 +111,7 @@ Eval:
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
......@@ -115,8 +126,7 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
......
......@@ -3,30 +3,84 @@ Global:
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_funsd
save_model_dir: ./output/ser_vi_layoutxlm_xfund_zh_udml
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
eval_batch_step: [ 0, 19 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: output/ser_layoutxlm_funsd/res/
infer_img: ppstructure/docs/vqa/input/zh_val_42.jpg
save_res_path: ./output/ser_layoutxlm_xfund_zh/res
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
model_type: &model_type "vqa"
name: DistillationModel
algorithm: Distillation
Models:
Teacher:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
# one of base or vi
mode: vi
checkpoints:
num_classes: &num_classes 7
Student:
pretrained:
freeze_params: false
return_all_feats: true
model_type: *model_type
algorithm: *algorithm
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
# one of base or vi
mode: vi
checkpoints:
num_classes: *num_classes
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
name: CombinedLoss
loss_config_list:
- DistillationVQASerTokenLayoutLMLoss:
weight: 1.0
model_name_list: ["Student", "Teacher"]
key: backbone_out
num_classes: *num_classes
- DistillationSERDMLLoss:
weight: 1.0
act: "softmax"
use_log: true
model_name_pairs:
- ["Student", "Teacher"]
key: backbone_out
- DistillationVQADistanceLoss:
weight: 0.5
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_5
name: "loss_5"
- DistillationVQADistanceLoss:
weight: 0.5
mode: "l2"
model_name_pairs:
- ["Student", "Teacher"]
key: hidden_states_8
name: "loss_8"
Optimizer:
name: AdamW
......@@ -36,25 +90,29 @@ Optimizer:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/FUNSD/class_list.txt
name: DistillationSerPostProcess
model_name: ["Student", "Teacher"]
key: backbone_out
class_path: &class_path train_data/XFUND/class_list_xfun.txt
Metric:
name: VQASerTokenMetric
name: DistillationMetric
base_metric_name: VQASerTokenMetric
main_indicator: hmean
key: "Student"
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
data_dir: train_data/XFUND/zh_train/image
label_file_list:
- ./train_data/FUNSD/train.json
- train_data/XFUND/zh_train/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
......@@ -64,6 +122,8 @@ Train:
contains_re: False
algorithm: *algorithm
class_path: *class_path
# one of [None, "tb-yx"]
order_method: &order_method "tb-yx"
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
......@@ -78,20 +138,19 @@ Train:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
batch_size_per_card: 4
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/FUNSD/testing_data/images/
data_dir: train_data/XFUND/zh_val/image
label_file_list:
- ./train_data/FUNSD/test.json
- train_data/XFUND/zh_val/val.json
transforms:
- DecodeImage: # load image
img_mode: RGB
......@@ -100,6 +159,7 @@ Eval:
contains_re: False
algorithm: *algorithm
class_path: *class_path
order_method: *order_method
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
......@@ -114,10 +174,10 @@ Eval:
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels'] # dataloader will return list in this order
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/re_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 57 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/re_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForRe
pretrained: True
checkpoints:
Loss:
name: LossFromOutput
key: loss
reduction: mean
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 10
lr:
learning_rate: 0.00005
warmup_epoch: 10
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQAReTokenLayoutLMPostProcess
Metric:
name: VQAReTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: &class_path train_data/FUNSD/class_list.txt
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: True
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1./255.
mean: [0.485, 0.456, 0.406]
std: [0.229, 0.224, 0.225]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'entities', 'relations']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlm_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlm_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLM"
Transform:
Backbone:
name: LayoutLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: &use_textline_bbox_info True
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
use_textline_bbox_info: *use_textline_bbox_info
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_funsd
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 100 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/FUNSD/testing_data/images/83624198.png
save_res_path: ./output/ser_layoutlmv2_funsd/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 7
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path train_data/FUNSD/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/training_data/images/
label_file_list:
- ./train_data/FUNSD/train.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/FUNSD/testing_data/images/
label_file_list:
- ./train_data/FUNSD/test.json
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutlmv2_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: ./output/ser_layoutlmv2_sroie/res/
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutLMv2"
Transform:
Backbone:
name: LayoutLMv2ForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 200
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_sroie
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data/SROIE/test/X00016469670.jpg
save_res_path: res_img_aug_with_gt
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 9
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/SROIE/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/SROIE/train
label_file_list:
- ./train_data/SROIE/train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/SROIE/test
label_file_list:
- ./train_data/SROIE/test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
Global:
use_gpu: True
epoch_num: &epoch_num 100
log_smooth_window: 10
print_batch_step: 10
save_model_dir: ./output/ser_layoutxlm_wildreceipt
save_epoch_step: 2000
# evaluation is run every 10 iterations after the 0th iteration
eval_batch_step: [ 0, 200 ]
cal_metric_during_train: False
save_inference_dir:
use_visualdl: False
seed: 2022
infer_img: train_data//wildreceipt/image_files/Image_12/10/845be0dd6f5b04866a2042abd28d558032ef2576.jpeg
save_res_path: ./output/ser_layoutxlm_wildreceipt/res
Architecture:
model_type: vqa
algorithm: &algorithm "LayoutXLM"
Transform:
Backbone:
name: LayoutXLMForSer
pretrained: True
checkpoints:
num_classes: &num_classes 51
Loss:
name: VQASerTokenLayoutLMLoss
num_classes: *num_classes
Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
lr:
name: Linear
learning_rate: 0.00005
epochs: *epoch_num
warmup_epoch: 2
regularizer:
name: L2
factor: 0.00000
PostProcess:
name: VQASerTokenLayoutLMPostProcess
class_path: &class_path ./train_data/wildreceipt/class_list.txt
Metric:
name: VQASerTokenMetric
main_indicator: hmean
Train:
dataset:
name: SimpleDataSet
data_dir: ./train_data/wildreceipt/
label_file_list:
- ./train_data/wildreceipt/wildreceipt_train.txt
ratio_list: [ 1.0 ]
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: &max_seq_len 512
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: True
drop_last: False
batch_size_per_card: 8
num_workers: 4
Eval:
dataset:
name: SimpleDataSet
data_dir: train_data/wildreceipt
label_file_list:
- ./train_data/wildreceipt/wildreceipt_test.txt
transforms:
- DecodeImage: # load image
img_mode: RGB
channel_first: False
- VQATokenLabelEncode: # Class handling label
contains_re: False
algorithm: *algorithm
class_path: *class_path
- VQATokenPad:
max_seq_len: *max_seq_len
return_attention_mask: True
- VQASerTokenChunk:
max_seq_len: *max_seq_len
- Resize:
size: [224,224]
- NormalizeImage:
scale: 1
mean: [ 123.675, 116.28, 103.53 ]
std: [ 58.395, 57.12, 57.375 ]
order: 'hwc'
- ToCHWImage:
- KeepKeys:
# dataloader will return list in this order
keep_keys: [ 'input_ids', 'bbox', 'attention_mask', 'token_type_ids', 'image', 'labels']
loader:
shuffle: False
drop_last: False
batch_size_per_card: 8
num_workers: 4
......@@ -24,6 +24,7 @@ from shapely.geometry import LineString, Point, Polygon
import json
import copy
from ppocr.utils.logging import get_logger
from ppocr.data.imaug.vqa.augment import order_by_tbyx
class ClsLabelEncode(object):
......@@ -870,6 +871,7 @@ class VQATokenLabelEncode(object):
add_special_ids=False,
algorithm='LayoutXLM',
use_textline_bbox_info=True,
order_method=None,
infer_mode=False,
ocr_engine=None,
**kwargs):
......@@ -899,6 +901,8 @@ class VQATokenLabelEncode(object):
self.infer_mode = infer_mode
self.ocr_engine = ocr_engine
self.use_textline_bbox_info = use_textline_bbox_info
self.order_method = order_method
assert self.order_method in [None, "tb-yx"]
def split_bbox(self, bbox, text, tokenizer):
words = text.split()
......@@ -938,6 +942,14 @@ class VQATokenLabelEncode(object):
# load bbox and label info
ocr_info = self._load_ocr_info(data)
for idx in range(len(ocr_info)):
if "bbox" not in ocr_info[idx]:
ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
"points"])
if self.order_method == "tb-yx":
ocr_info = order_by_tbyx(ocr_info)
# for re
train_re = self.contains_re and not self.infer_mode
if train_re:
......@@ -1052,10 +1064,10 @@ class VQATokenLabelEncode(object):
return data
def trans_poly_to_bbox(self, poly):
x1 = np.min([p[0] for p in poly])
x2 = np.max([p[0] for p in poly])
y1 = np.min([p[1] for p in poly])
y2 = np.max([p[1] for p in poly])
x1 = int(np.min([p[0] for p in poly]))
x2 = int(np.max([p[0] for p in poly]))
y1 = int(np.min([p[1] for p in poly]))
y2 = int(np.max([p[1] for p in poly]))
return [x1, y1, x2, y2]
def _load_ocr_info(self, data):
......
......@@ -13,12 +13,10 @@
# limitations under the License.
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
from .augment import DistortBBox
__all__ = [
'VQATokenPad',
'VQASerTokenChunk',
'VQAReTokenChunk',
'VQAReTokenRelation',
'DistortBBox',
]
......@@ -16,22 +16,18 @@ import os
import sys
import numpy as np
import random
from copy import deepcopy
class DistortBBox:
def __init__(self, prob=0.5, max_scale=1, **kwargs):
"""Random distort bbox
"""
self.prob = prob
self.max_scale = max_scale
def __call__(self, data):
if random.random() > self.prob:
return data
bbox = np.array(data['bbox'])
rnd_scale = (np.random.rand(*bbox.shape) - 0.5) * 2 * self.max_scale
bbox = np.round(bbox + rnd_scale).astype(bbox.dtype)
data['bbox'] = np.clip(data['bbox'], 0, 1000)
data['bbox'] = bbox.tolist()
sys.stdout.flush()
return data
def order_by_tbyx(ocr_info):
res = sorted(ocr_info, key=lambda r: (r["bbox"][1], r["bbox"][0]))
for i in range(len(res) - 1):
for j in range(i, 0, -1):
if abs(res[j + 1]["bbox"][1] - res[j]["bbox"][1]) < 20 and \
(res[j + 1]["bbox"][0] < res[j]["bbox"][0]):
tmp = deepcopy(res[j])
res[j] = deepcopy(res[j + 1])
res[j + 1] = deepcopy(tmp)
else:
break
return res
......@@ -63,18 +63,21 @@ class KLJSLoss(object):
def __call__(self, p1, p2, reduction="mean"):
if self.mode.lower() == 'kl':
loss = paddle.multiply(p2, paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
loss = paddle.multiply(p2,
paddle.log((p2 + 1e-5) / (p1 + 1e-5) + 1e-5))
loss += paddle.multiply(
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
p1, paddle.log((p1 + 1e-5) / (p2 + 1e-5) + 1e-5))
loss *= 0.5
elif self.mode.lower() == "js":
loss = paddle.multiply(p2, paddle.log((2*p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss = paddle.multiply(
p2, paddle.log((2 * p2 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss += paddle.multiply(
p1, paddle.log((2*p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
p1, paddle.log((2 * p1 + 1e-5) / (p1 + p2 + 1e-5) + 1e-5))
loss *= 0.5
else:
raise ValueError("The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
raise ValueError(
"The mode.lower() if KLJSLoss should be one of ['kl', 'js']")
if reduction == "mean":
loss = paddle.mean(loss, axis=[1, 2])
elif reduction == "none" or reduction is None:
......@@ -154,7 +157,9 @@ class LossFromOutput(nn.Layer):
self.reduction = reduction
def forward(self, predicts, batch):
loss = predicts[self.key]
loss = predicts
if self.key is not None and isinstance(predicts, dict):
loss = loss[self.key]
if self.reduction == 'mean':
loss = paddle.mean(loss)
elif self.reduction == 'sum':
......
......@@ -24,6 +24,9 @@ from .distillation_loss import DistillationCTCLoss
from .distillation_loss import DistillationSARLoss
from .distillation_loss import DistillationDMLLoss
from .distillation_loss import DistillationDistanceLoss, DistillationDBLoss, DistillationDilaDBLoss
from .distillation_loss import DistillationVQASerTokenLayoutLMLoss, DistillationSERDMLLoss
from .distillation_loss import DistillationLossFromOutput
from .distillation_loss import DistillationVQADistanceLoss
class CombinedLoss(nn.Layer):
......
......@@ -21,8 +21,10 @@ from .rec_ctc_loss import CTCLoss
from .rec_sar_loss import SARLoss
from .basic_loss import DMLLoss
from .basic_loss import DistanceLoss
from .basic_loss import LossFromOutput
from .det_db_loss import DBLoss
from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
def _sum_loss(loss_dict):
......@@ -322,3 +324,133 @@ class DistillationDistanceLoss(DistanceLoss):
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
class DistillationVQASerTokenLayoutLMLoss(VQASerTokenLayoutLMLoss):
def __init__(self,
num_classes,
model_name_list=[],
key=None,
name="loss_ser"):
super().__init__(num_classes=num_classes)
self.model_name_list = model_name_list
self.key = key
self.name = name
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.key is not None:
out = out[self.key]
loss = super().forward(out, batch)
loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
return loss_dict
class DistillationLossFromOutput(LossFromOutput):
def __init__(self,
reduction="none",
model_name_list=[],
dist_key=None,
key="loss",
name="loss_re"):
super().__init__(key=key, reduction=reduction)
self.model_name_list = model_name_list
self.name = name
self.dist_key = dist_key
def forward(self, predicts, batch):
loss_dict = dict()
for idx, model_name in enumerate(self.model_name_list):
out = predicts[model_name]
if self.dist_key is not None:
out = out[self.dist_key]
loss = super().forward(out, batch)
loss_dict["{}_{}".format(self.name, model_name)] = loss["loss"]
return loss_dict
class DistillationSERDMLLoss(DMLLoss):
"""
"""
def __init__(self,
act="softmax",
use_log=True,
num_classes=7,
model_name_pairs=[],
key=None,
name="loss_dml_ser"):
super().__init__(act=act, use_log=use_log)
assert isinstance(model_name_pairs, list)
self.key = key
self.name = name
self.num_classes = num_classes
self.model_name_pairs = model_name_pairs
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
out1 = out1.reshape([-1, out1.shape[-1]])
out2 = out2.reshape([-1, out2.shape[-1]])
attention_mask = batch[2]
if attention_mask is not None:
active_output = attention_mask.reshape([-1, ]) == 1
out1 = out1[active_output]
out2 = out2[active_output]
loss_dict["{}_{}".format(self.name, idx)] = super().forward(out1,
out2)
return loss_dict
class DistillationVQADistanceLoss(DistanceLoss):
def __init__(self,
mode="l2",
model_name_pairs=[],
key=None,
name="loss_distance",
**kargs):
super().__init__(mode=mode, **kargs)
assert isinstance(model_name_pairs, list)
self.key = key
self.model_name_pairs = model_name_pairs
self.name = name + "_l2"
def forward(self, predicts, batch):
loss_dict = dict()
for idx, pair in enumerate(self.model_name_pairs):
out1 = predicts[pair[0]]
out2 = predicts[pair[1]]
attention_mask = batch[2]
if self.key is not None:
out1 = out1[self.key]
out2 = out2[self.key]
if attention_mask is not None:
max_len = attention_mask.shape[-1]
out1 = out1[:, :max_len]
out2 = out2[:, :max_len]
out1 = out1.reshape([-1, out1.shape[-1]])
out2 = out2.reshape([-1, out2.shape[-1]])
if attention_mask is not None:
active_output = attention_mask.reshape([-1, ]) == 1
out1 = out1[active_output]
out2 = out2[active_output]
loss = super().forward(out1, out2)
if isinstance(loss, dict):
for key in loss:
loss_dict["{}_{}nohu_{}".format(self.name, key,
idx)] = loss[key]
else:
loss_dict["{}_{}_{}_{}".format(self.name, pair[0], pair[1],
idx)] = loss
return loss_dict
......@@ -17,26 +17,30 @@ from __future__ import division
from __future__ import print_function
from paddle import nn
from ppocr.losses.basic_loss import DMLLoss
class VQASerTokenLayoutLMLoss(nn.Layer):
def __init__(self, num_classes):
def __init__(self, num_classes, key=None):
super().__init__()
self.loss_class = nn.CrossEntropyLoss()
self.num_classes = num_classes
self.ignore_index = self.loss_class.ignore_index
self.key = key
def forward(self, predicts, batch):
if isinstance(predicts, dict) and self.key is not None:
predicts = predicts[self.key]
labels = batch[5]
attention_mask = batch[2]
if attention_mask is not None:
active_loss = attention_mask.reshape([-1, ]) == 1
active_outputs = predicts.reshape(
active_output = predicts.reshape(
[-1, self.num_classes])[active_loss]
active_labels = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_outputs, active_labels)
active_label = labels.reshape([-1, ])[active_loss]
loss = self.loss_class(active_output, active_label)
else:
loss = self.loss_class(
predicts.reshape([-1, self.num_classes]),
labels.reshape([-1, ]))
return {'loss': loss}
return {'loss': loss}
\ No newline at end of file
......@@ -19,6 +19,8 @@ from .rec_metric import RecMetric
from .det_metric import DetMetric
from .e2e_metric import E2EMetric
from .cls_metric import ClsMetric
from .vqa_token_ser_metric import VQASerTokenMetric
from .vqa_token_re_metric import VQAReTokenMetric
class DistillationMetric(object):
......
......@@ -73,28 +73,40 @@ class BaseModel(nn.Layer):
self.return_all_feats = config.get("return_all_feats", False)
def forward(self, x, data=None):
y = dict()
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
y["backbone_out"] = x
if isinstance(x, dict):
y.update(x)
else:
y["backbone_out"] = x
final_name = "backbone_out"
if self.use_neck:
x = self.neck(x)
y["neck_out"] = x
if isinstance(x, dict):
y.update(x)
else:
y["neck_out"] = x
final_name = "neck_out"
if self.use_head:
x = self.head(x, targets=data)
# for multi head, save ctc neck out for udml
if isinstance(x, dict) and 'ctc_neck' in x.keys():
y["neck_out"] = x["ctc_neck"]
y["head_out"] = x
elif isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x
# for multi head, save ctc neck out for udml
if isinstance(x, dict) and 'ctc_neck' in x.keys():
y["neck_out"] = x["ctc_neck"]
y["head_out"] = x
elif isinstance(x, dict):
y.update(x)
else:
y["head_out"] = x
final_name = "head_out"
if self.return_all_feats:
if self.training:
return y
elif isinstance(x, dict):
return x
else:
return {"head_out": y["head_out"]}
return {final_name: x}
else:
return x
......@@ -22,13 +22,22 @@ from paddle import nn
from paddlenlp.transformers import LayoutXLMModel, LayoutXLMForTokenClassification, LayoutXLMForRelationExtraction
from paddlenlp.transformers import LayoutLMModel, LayoutLMForTokenClassification
from paddlenlp.transformers import LayoutLMv2Model, LayoutLMv2ForTokenClassification, LayoutLMv2ForRelationExtraction
from paddlenlp.transformers import AutoModel
__all__ = ["LayoutXLMForSer", 'LayoutLMForSer']
__all__ = ["LayoutXLMForSer", "LayoutLMForSer"]
pretrained_model_dict = {
LayoutXLMModel: 'layoutxlm-base-uncased',
LayoutLMModel: 'layoutlm-base-uncased',
LayoutLMv2Model: 'layoutlmv2-base-uncased'
LayoutXLMModel: {
"base": "layoutxlm-base-uncased",
"vi": "layoutxlm-wo-backbone-base-uncased",
},
LayoutLMModel: {
"base": "layoutlm-base-uncased",
},
LayoutLMv2Model: {
"base": "layoutlmv2-base-uncased",
"vi": "layoutlmv2-wo-backbone-base-uncased",
},
}
......@@ -36,42 +45,47 @@ class NLPBaseModel(nn.Layer):
def __init__(self,
base_model_class,
model_class,
type='ser',
mode="base",
type="ser",
pretrained=True,
checkpoints=None,
**kwargs):
super(NLPBaseModel, self).__init__()
if checkpoints is not None:
if checkpoints is not None: # load the trained model
self.model = model_class.from_pretrained(checkpoints)
elif isinstance(pretrained, (str, )) and os.path.exists(pretrained):
self.model = model_class.from_pretrained(pretrained)
else:
pretrained_model_name = pretrained_model_dict[base_model_class]
else: # load the pretrained-model
pretrained_model_name = pretrained_model_dict[base_model_class][
mode]
if pretrained is True:
base_model = base_model_class.from_pretrained(
pretrained_model_name)
else:
base_model = base_model_class(
**base_model_class.pretrained_init_configuration[
pretrained_model_name])
if type == 'ser':
base_model = base_model_class.from_pretrained(pretrained)
if type == "ser":
self.model = model_class(
base_model, num_classes=kwargs['num_classes'], dropout=None)
base_model, num_classes=kwargs["num_classes"], dropout=None)
else:
self.model = model_class(base_model, dropout=None)
self.out_channels = 1
self.use_visual_backbone = True
class LayoutLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
def __init__(self,
num_classes,
pretrained=True,
checkpoints=None,
mode="base",
**kwargs):
super(LayoutLMForSer, self).__init__(
LayoutLMModel,
LayoutLMForTokenClassification,
'ser',
mode,
"ser",
pretrained,
checkpoints,
num_classes=num_classes)
num_classes=num_classes, )
self.use_visual_backbone = False
def forward(self, x):
x = self.model(
......@@ -85,62 +99,92 @@ class LayoutLMForSer(NLPBaseModel):
class LayoutLMv2ForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
def __init__(self,
num_classes,
pretrained=True,
checkpoints=None,
mode="base",
**kwargs):
super(LayoutLMv2ForSer, self).__init__(
LayoutLMv2Model,
LayoutLMv2ForTokenClassification,
'ser',
mode,
"ser",
pretrained,
checkpoints,
num_classes=num_classes)
self.use_visual_backbone = True
if hasattr(self.model.layoutlmv2, "use_visual_backbone"
) and self.model.layoutlmv2.use_visual_backbone is False:
self.use_visual_backbone = False
def forward(self, x):
if self.use_visual_backbone is True:
image = x[4]
else:
image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
image=x[4],
image=image,
position_ids=None,
head_mask=None,
labels=None)
if not self.training:
if self.training:
res = {"backbone_out": x[0]}
res.update(x[1])
return res
else:
return x
return x[0]
class LayoutXLMForSer(NLPBaseModel):
def __init__(self, num_classes, pretrained=True, checkpoints=None,
def __init__(self,
num_classes,
pretrained=True,
checkpoints=None,
mode="base",
**kwargs):
super(LayoutXLMForSer, self).__init__(
LayoutXLMModel,
LayoutXLMForTokenClassification,
'ser',
mode,
"ser",
pretrained,
checkpoints,
num_classes=num_classes)
self.use_visual_backbone = True
def forward(self, x):
if self.use_visual_backbone is True:
image = x[4]
else:
image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
image=x[4],
image=image,
position_ids=None,
head_mask=None,
labels=None)
if not self.training:
if self.training:
res = {"backbone_out": x[0]}
res.update(x[1])
return res
else:
return x
return x[0]
class LayoutLMv2ForRe(NLPBaseModel):
def __init__(self, pretrained=True, checkpoints=None, **kwargs):
super(LayoutLMv2ForRe, self).__init__(LayoutLMv2Model,
LayoutLMv2ForRelationExtraction,
're', pretrained, checkpoints)
def __init__(self, pretrained=True, checkpoints=None, mode="base",
**kwargs):
super(LayoutLMv2ForRe, self).__init__(
LayoutLMv2Model, LayoutLMv2ForRelationExtraction, mode, "re",
pretrained, checkpoints)
def forward(self, x):
x = self.model(
......@@ -158,18 +202,27 @@ class LayoutLMv2ForRe(NLPBaseModel):
class LayoutXLMForRe(NLPBaseModel):
def __init__(self, pretrained=True, checkpoints=None, **kwargs):
super(LayoutXLMForRe, self).__init__(LayoutXLMModel,
LayoutXLMForRelationExtraction,
're', pretrained, checkpoints)
def __init__(self, pretrained=True, checkpoints=None, mode="base",
**kwargs):
super(LayoutXLMForRe, self).__init__(
LayoutXLMModel, LayoutXLMForRelationExtraction, mode, "re",
pretrained, checkpoints)
self.use_visual_backbone = True
if hasattr(self.model.layoutxlm, "use_visual_backbone"
) and self.model.layoutxlm.use_visual_backbone is False:
self.use_visual_backbone = False
def forward(self, x):
if self.use_visual_backbone is True:
image = x[4]
else:
image = None
x = self.model(
input_ids=x[0],
bbox=x[1],
attention_mask=x[2],
token_type_ids=x[3],
image=x[4],
image=image,
position_ids=None,
head_mask=None,
labels=None,
......
......@@ -31,21 +31,38 @@ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode, \
SPINLabelDecode
from .cls_postprocess import ClsPostProcess
from .pg_postprocess import PGPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess
from .vqa_token_ser_layoutlm_postprocess import VQASerTokenLayoutLMPostProcess, DistillationSerPostProcess
from .vqa_token_re_layoutlm_postprocess import VQAReTokenLayoutLMPostProcess, DistillationRePostProcess
from .table_postprocess import TableMasterLabelDecode, TableLabelDecode
def build_post_process(config, global_config=None):
support_dict = [
'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'FCEPostProcess',
'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode',
'PGPostProcess', 'DistillationCTCLabelDecode', 'TableLabelDecode',
'DistillationDBPostProcess', 'NRTRLabelDecode', 'SARLabelDecode',
'SEEDLabelDecode', 'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess', 'PRENLabelDecode',
'DistillationSARLabelDecode', 'ViTSTRLabelDecode', 'ABINetLabelDecode',
'TableMasterLabelDecode', 'SPINLabelDecode'
'DBPostProcess',
'EASTPostProcess',
'SASTPostProcess',
'FCEPostProcess',
'CTCLabelDecode',
'AttnLabelDecode',
'ClsPostProcess',
'SRNLabelDecode',
'PGPostProcess',
'DistillationCTCLabelDecode',
'TableLabelDecode',
'DistillationDBPostProcess',
'NRTRLabelDecode',
'SARLabelDecode',
'SEEDLabelDecode',
'VQASerTokenLayoutLMPostProcess',
'VQAReTokenLayoutLMPostProcess',
'PRENLabelDecode',
'DistillationSARLabelDecode',
'ViTSTRLabelDecode',
'ABINetLabelDecode',
'TableMasterLabelDecode',
'SPINLabelDecode',
'DistillationSerPostProcess',
'DistillationRePostProcess',
]
if config['name'] == 'PSEPostProcess':
......
......@@ -49,3 +49,25 @@ class VQAReTokenLayoutLMPostProcess(object):
result.append((ocr_info_head, ocr_info_tail))
results.append(result)
return results
class DistillationRePostProcess(VQAReTokenLayoutLMPostProcess):
"""
DistillationRePostProcess
"""
def __init__(self, model_name=["Student"], key=None, **kwargs):
super().__init__(**kwargs)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def __call__(self, preds, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
output[name] = super().__call__(pred, *args, **kwargs)
return output
......@@ -93,3 +93,25 @@ class VQASerTokenLayoutLMPostProcess(object):
ocr_info[idx]["pred"] = self.id2label_map_for_show[int(pred_id)]
results.append(ocr_info)
return results
class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
"""
DistillationSerPostProcess
"""
def __init__(self, class_path, model_name=["Student"], key=None, **kwargs):
super().__init__(class_path, **kwargs)
if not isinstance(model_name, list):
model_name = [model_name]
self.model_name = model_name
self.key = key
def __call__(self, preds, batch=None, *args, **kwargs):
output = dict()
for name in self.model_name:
pred = preds[name]
if self.key is not None:
pred = pred[self.key]
output[name] = super().__call__(pred, batch=batch, *args, **kwargs)
return output
......@@ -55,6 +55,9 @@ def load_model(config, model, optimizer=None, model_type='det'):
best_model_dict = {}
if model_type == 'vqa':
# NOTE: for vqa model, resume training is not supported now
if config["Architecture"]["algorithm"] in ["Distillation"]:
return best_model_dict
checkpoints = config['Architecture']['Backbone']['checkpoints']
# load vqa method metric
if checkpoints:
......@@ -78,6 +81,7 @@ def load_model(config, model, optimizer=None, model_type='det'):
logger.warning(
"{}.pdopt is not exists, params of optimizer is not loaded".
format(checkpoints))
return best_model_dict
if checkpoints:
......@@ -166,15 +170,19 @@ def save_model(model,
"""
_mkdir_if_not_exist(model_path, logger)
model_prefix = os.path.join(model_path, prefix)
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
if config['Architecture']["model_type"] != 'vqa':
paddle.save(model.state_dict(), model_prefix + '.pdparams')
metric_prefix = model_prefix
else:
else: # for vqa system, we follow the save/load rules in NLP
if config['Global']['distributed']:
model._layers.backbone.model.save_pretrained(model_prefix)
arch = model._layers
else:
model.backbone.model.save_pretrained(model_prefix)
arch = model
if config["Architecture"]["algorithm"] in ["Distillation"]:
arch = arch.Student
arch.backbone.model.save_pretrained(model_prefix)
metric_prefix = os.path.join(model_prefix, 'metric')
# save metric and config
with open(metric_prefix + '.states', 'wb') as f:
......
......@@ -38,6 +38,7 @@ def init_args():
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--min_subgraph_size", type=int, default=15)
parser.add_argument("--shape_info_filename", type=str, default=None)
parser.add_argument("--precision", type=str, default="fp32")
parser.add_argument("--gpu_mem", type=int, default=500)
......@@ -204,9 +205,18 @@ def create_predictor(args, mode, logger):
workspace_size=1 << 30,
precision_mode=precision,
max_batch_size=args.max_batch_size,
min_subgraph_size=args.min_subgraph_size,
min_subgraph_size=args.min_subgraph_size, # skip the minmum trt subgraph
use_calib_mode=False)
# skip the minmum trt subgraph
# collect shape
if args.shape_info_filename is not None:
if not os.path.exists(args.shape_info_filename):
config.collect_shape_range_info(args.shape_info_filename)
logger.info(f"collect dynamic shape info into : {args.shape_info_filename}")
else:
logger.info(f"dynamic shape info file( {args.shape_info_filename} ) already exists, not need to generate again.")
config.enable_tuned_tensorrt_dynamic_shape(args.shape_info_filename, True)
use_dynamic_shape = True
if mode == "det":
min_input_shape = {
......
......@@ -113,10 +113,13 @@ def make_input(ser_inputs, ser_results):
class SerRePredictor(object):
def __init__(self, config, ser_config):
global_config = config['Global']
if "infer_mode" in global_config:
ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
self.ser_engine = SerPredictor(ser_config)
# init re model
global_config = config['Global']
# build post process
self.post_process_class = build_post_process(config['PostProcess'],
......@@ -130,8 +133,8 @@ class SerRePredictor(object):
self.model.eval()
def __call__(self, img_path):
ser_results, ser_inputs = self.ser_engine({'img_path': img_path})
def __call__(self, data):
ser_results, ser_inputs = self.ser_engine(data)
re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
preds = self.model(re_input)
post_result = self.post_process_class(
......@@ -173,18 +176,33 @@ if __name__ == '__main__':
ser_re_engine = SerRePredictor(config, ser_config)
infer_imgs = get_image_file_list(config['Global']['infer_img'])
if config["Global"].get("infer_mode", None) is False:
data_dir = config['Eval']['dataset']['data_dir']
with open(config['Global']['infer_img'], "rb") as f:
infer_imgs = f.readlines()
else:
infer_imgs = get_image_file_list(config['Global']['infer_img'])
with open(
os.path.join(config['Global']['save_res_path'],
"infer_results.txt"),
"w",
encoding='utf-8') as fout:
for idx, img_path in enumerate(infer_imgs):
for idx, info in enumerate(infer_imgs):
if config["Global"].get("infer_mode", None) is False:
data_line = info.decode('utf-8')
substr = data_line.strip("\n").split("\t")
img_path = os.path.join(data_dir, substr[0])
data = {'img_path': img_path, 'label': substr[1]}
else:
img_path = info
data = {'img_path': img_path}
save_img_path = os.path.join(
config['Global']['save_res_path'],
os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
result = ser_re_engine(img_path)
result = ser_re_engine(data)
result = result[0]
fout.write(img_path + "\t" + json.dumps(
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册